9.4 Binary Classification Metrics
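Note: This code may take several minutes (or longer) to run, because a random forest is fitted for every candidate in the tuning grid on each of the five cross-validation folds.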
# Cross-fold validation ----
# 5-fold cross-validation resamples of `ames`, used below to tune the model.
# The seed makes the fold assignment (and the random forest fits that follow)
# reproducible; stratifying on the outcome keeps the class balance of
# `under_budget` similar across folds.
set.seed(123)
ames_folds_binary <- vfold_cv(ames, v = 5, strata = under_budget)
# Create recipe ----
# Declares the binary outcome `under_budget` and the six predictors taken
# from `ames`. No preprocessing steps are added, so the predictors reach the
# model unchanged.
ames_recipe_binary <- recipe(
  formula = under_budget ~ Gr_Liv_Area + Full_Bath + Half_Bath + Lot_Area +
    Neighborhood + Overall_Cond,
  data = ames
)
# Set the model and hyperparameters ----
# Random forest classifier fitted with the ranger engine. `mtry`, `trees`
# and `min_n` are placeholders (tune()) whose values are supplied by the
# tuning grid in tune_grid().
ames_spec_binary <- rand_forest(
  mtry = tune(),
  trees = tune(),
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger")
# Create workflow ----
# Bundles the recipe (outcome + predictors) with the tunable random forest
# specification so they can be tuned and fitted together.
ames_workflow_binary <- workflow() %>%
  add_recipe(ames_recipe_binary) %>%
  add_model(ames_spec_binary)
# Tune the random forest, scoring every candidate with all binary metrics ----
# Every class metric compared in this section is computed for each
# candidate, so the best candidate can later be selected per metric.
ames_metrics_binary <- metric_set(
  sens, spec, recall, precision, mcc, j_index, f_meas, accuracy,
  kap, ppv, npv, bal_accuracy, detection_prevalence
)

# Regular grid with 10 levels per tuning parameter: up to 10^3 candidates,
# each fitted on every resampling fold -- the main source of the long run
# time.
ames_grid_binary <- grid_regular(
  mtry(range = c(2, 6)),
  min_n(range = c(2, 20)),
  trees(range = c(10, 100)),
  levels = 10
)

ames_tune_binary <- tune_grid(
  ames_workflow_binary,
  resamples = ames_folds_binary,
  grid = ames_grid_binary,
  metrics = ames_metrics_binary
)
# Pick the best model for each metric and pull out the predictions ----
# For each metric name:
#   1. select the best hyperparameters from the tuning results,
#   2. finalize the workflow with them,
#   3. refit on the full `ames` data,
#   4. predict back on that same data.
# These are in-sample (resubstitution) predictions, so the confusion matrices
# built from them are likely optimistic compared to held-out performance.
best_models_binary <- tibble(
  metric_name = c(
    "recall", "sens", "spec", "precision", "mcc", "j_index", "f_meas",
    "accuracy", "kap", "ppv", "npv", "bal_accuracy", "detection_prevalence"
  )
) %>%
  mutate(
    # `metric` is passed by name: tune >= 1.2.0 places it after `...`, so
    # positional use is deprecated there.
    metric_best = map(
      metric_name,
      ~ select_best(ames_tune_binary, metric = .x)
    ),
    wf_best = map(metric_best, ~ finalize_workflow(ames_workflow_binary, .x)),
    fit_best = map(wf_best, ~ fit(.x, data = ames)),
    df_pred = map(
      fit_best,
      ~ ames %>%
        bind_cols(predict(.x, new_data = ames)) %>%
        select(under_budget, .pred_class)
    )
  ) %>%
  select(-c(wf_best, fit_best)) %>%
  # Only the list-columns need unnesting; `metric_name` is already atomic
  # and is repeated for every prediction row of its model.
  unnest(cols = c(metric_best, df_pred))
# Plot! ----
# One confusion-matrix heatmap per metric. Each tile at (truth, prediction)
# shows how many observations fell into that cell for the model selected as
# best by that metric.
binary_metric_levels <- c(
  "recall", "sens", "spec", "precision", "mcc", "j_index", "f_meas",
  "accuracy", "kap", "ppv", "npv", "bal_accuracy", "detection_prevalence"
)

# Facet labels: metric name plus a short plain-language description, in the
# same order as `binary_metric_levels`.
binary_metric_labels <- c(
  "recall\nhow many observations out \nof all positive observations \nhave we classified as positive",
  "sens\nhow many observations out \nof all positive observations \nhave we classified as positive",
  "spec\nhow many observations out \nof all negative observations \nhave we classified as negative",
  "precision\nhow many observations \npredicted as positive are \nin fact positive",
  "mcc\ncorrelation between \npredicted classes and ground truth",
  "j_index\nbalance between \nsensitivity and specificity",
  "f_meas\nbalance between \nprecision and recall",
  "accuracy\nhow many observations,\n both positive and negative,\n were correctly classified",
  "kap\nhow much better is your model\n over the random classifier\n that predicts based on class frequencies",
  "ppv\nhow many observations\n predicted as positive\n are in fact positive",
  "npv\nhow many predictions\n out of all negative\n predictions were correct",
  "bal_accuracy\nbalance between\n sensitivity and specificity",
  # Detection prevalence = (TP + FP) / total predictions: the share of
  # predictions that are positive, regardless of whether they are correct.
  "detection_prevalence\nhow many predictions\n were positive out of\n all the predictions"
)

best_models_binary %>%
  mutate(metric_desc = factor(
    metric_name,
    levels = binary_metric_levels,
    labels = binary_metric_labels
  )) %>%
  # count() returns an ungrouped tibble, so no explicit ungroup() is needed.
  count(metric_desc, under_budget, .pred_class, name = "bin_count") %>%
  ggplot(aes(
    x = under_budget,
    y = .pred_class,
    fill = bin_count,
    label = bin_count
  )) +
  scale_fill_binned() +
  geom_tile() +
  geom_label() +
  coord_fixed() +
  facet_wrap(~metric_desc, ncol = 2) +
  theme_minimal() +
  theme(
    panel.spacing = unit(2, "lines"),
    strip.text.x = element_text(size = 8)
  )