8.38 Tuning the model
Let’s try to tune the cost_complexity of the decision tree to find a better level of complexity. We define a tree specification in which cost_complexity, tree_depth, and min_n are all marked with tune(), and pass it directly to the tuning function together with a model formula, so no intermediate workflow object is needed. Also, since the dataset is small (400 observations in total, of which 300 are in the training set), we’ll use bootstrap resampling to obtain a larger number of resamples for evaluation. WARNING: the number of bootstrap resamples has a direct effect on execution time. For academic purposes, a value of 100 is used here; in practice, the more resamples you use, the more stable the resampled performance estimates become.
set.seed(1234)
bootstraps_samples <- 100

carseats_boot <- bootstraps(carseats_train,
                            times = bootstraps_samples,
                            apparent = TRUE,
                            strata = High)
carseats_boot
## # Bootstrap sampling using stratification with apparent sample
## # A tibble: 101 × 2
## splits id
## <list> <chr>
## 1 <split [300/117]> Bootstrap001
## 2 <split [300/110]> Bootstrap002
## 3 <split [300/111]> Bootstrap003
## 4 <split [300/104]> Bootstrap004
## 5 <split [300/107]> Bootstrap005
## 6 <split [300/107]> Bootstrap006
## 7 <split [300/112]> Bootstrap007
## 8 <split [300/117]> Bootstrap008
## 9 <split [300/107]> Bootstrap009
## 10 <split [300/117]> Bootstrap010
## # ℹ 91 more rows
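The [300/117] notation of each split can be unpacked with rsample’s analysis() and assessment() helpers. A small illustrative check (the first_split name is just for demonstration): the analysis set is drawn with replacement from the 300 training rows, so it also has 300 rows, while the assessment set holds the out-of-bag rows that were never drawn.
# Unpack one resample with rsample (loaded as part of tidymodels)
first_split <- carseats_boot$splits[[1]]
nrow(analysis(first_split))    # 300: sampled with replacement from the training set
nrow(assessment(first_split))  # 117: the out-of-bag rows for this resample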
To be able to tune these hyperparameters we need two more objects: a set of resamples, for which we will use the bootstrap resamples created above instead of the usual k-fold cross-validation, and a grid of candidate values to try. Since we are only tuning three hyperparameters, it is fine to stay with a regular grid.
tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart") %>%
  set_mode("classification")
# Setup parallel processing ----
## [1] 9
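The chunk that registers the parallel backend is not echoed above; only its output (the number of worker processes, 9) is shown. A minimal sketch of one way to set this up, assuming the doParallel backend and keeping one core free, could be:
library(doParallel)

# Assumption: the printed 9 comes from counting the machine's cores and
# reserving one for the operating system
n_cores <- parallel::detectCores() - 1
n_cores

# Register the backend so tune_grid() can evaluate the bootstrap
# resamples in parallel
registerDoParallel(cores = n_cores)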
tree_grid <- grid_regular(cost_complexity(range = c(-4, -1)),
                          tree_depth(range = c(3, 7)),
                          min_n(range = c(10, 20)),
                          levels = 5
)
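As a quick sanity check (not part of the original code), the size of the grid explains the 125-row .metrics tibbles shown further down: levels = 5 crosses five values of each of the three parameters, giving 5 * 5 * 5 = 125 candidate combinations.
# 5 levels for each of cost_complexity, tree_depth and min_n
nrow(tree_grid)   # 125 candidate parameter combinations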
# set.seed(2001)
# tune_res <- tune_grid(
# tree_spec,
# High ~ .,
# resamples = carseats_boot,
# grid = tree_grid,
# metrics = metric_set(accuracy)
# )
# save tune_res
# write_rds(tune_res, 'data/08_tree_tune_grid_results.rds')
tune_res <- read_rds('data/08_tree_tune_grid_results.rds')
tune_res
## # Tuning results
## # Bootstrap sampling using stratification with apparent sample
## # A tibble: 101 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [300/300]> Apparent <tibble [125 × 7]> <tibble [0 × 3]>
## 2 <split [300/117]> Bootstrap001 <tibble [125 × 7]> <tibble [0 × 3]>
## 3 <split [300/110]> Bootstrap002 <tibble [125 × 7]> <tibble [0 × 3]>
## 4 <split [300/111]> Bootstrap003 <tibble [125 × 7]> <tibble [0 × 3]>
## 5 <split [300/104]> Bootstrap004 <tibble [125 × 7]> <tibble [0 × 3]>
## 6 <split [300/107]> Bootstrap005 <tibble [125 × 7]> <tibble [0 × 3]>
## 7 <split [300/107]> Bootstrap006 <tibble [125 × 7]> <tibble [0 × 3]>
## 8 <split [300/112]> Bootstrap007 <tibble [125 × 7]> <tibble [0 × 3]>
## 9 <split [300/117]> Bootstrap008 <tibble [125 × 7]> <tibble [0 × 3]>
## 10 <split [300/107]> Bootstrap009 <tibble [125 × 7]> <tibble [0 × 3]>
## # ℹ 91 more rows
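With the cached results loaded, the usual tune helpers can be used to summarise them. A short sketch, assuming we rank by the accuracy metric collected during tuning:
# Mean accuracy over the bootstrap resamples for each of the 125 candidates
collect_metrics(tune_res)

# The best-performing combinations of cost_complexity, tree_depth and min_n
show_best(tune_res, metric = "accuracy", n = 5)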