7.3 Modeling with workflows

# Seed the RNG right before the random split so the train/test partition (and
# every model fitted from it further down) is reproducible from run to run.
# NOTE(review): the seed value itself is arbitrary; metrics printed by an
# earlier, differently seeded run may differ slightly from a fresh run.
set.seed(2024)

# Model the diameter on the log10 scale, then hold out 20% of rows for testing.
trees_split <- trees_cleaned %>%
  mutate(diam = log10(diam)) %>%
  initial_split(prop = 0.8)
trees_training <- training(trees_split)
trees_testing <- testing(trees_split)
# Preprocessing recipe: predict `diam` from every other column of the
# training set.
trees_recipe <- recipe(diam ~ ., data = trees_training) %>%
  # Keep the identifier columns in the data but out of the predictor set.
  update_role(tree_id, address, new_role = "id") %>%
  # Add an indicator column flagging rows where `date_planted` is missing.
  step_indicate_na(date_planted) %>%
  # Deliberately naive imputation: missing planting dates become a fixed
  # placeholder date; the indicator column above records which rows were filled.
  step_mutate(
    date_planted = if_else(
      is.na(date_planted),
      as.Date('1950-01-01'),
      date_planted
    )
  ) %>%
  # Pool infrequent factor levels (threshold = 0.01) into a single "other" level.
  step_other(all_nominal_predictors(), threshold = 0.01) %>%
  # Expand the remaining nominal predictors into dummy columns.
  step_dummy(all_nominal_predictors())

# Ordinary least-squares regression via the base-R `lm` engine.
linear_model_spec <- set_engine(linear_reg(), "lm")

# Bundle preprocessing and model into a single workflow object.
tree_workflow_lm <- workflow() %>%
  add_recipe(trees_recipe) %>%
  add_model(linear_model_spec)

# Fitting the workflow preps the recipe and fits the model on the training set.
fitted_workflow_lm <- fit(tree_workflow_lm, data = trees_training)

# Inspect the third recipe step (step_other): which factor levels were retained.
fitted_workflow_lm %>%
  extract_recipe() %>%
  tidy(number = 3)
## # A tibble: 35 × 3
##    terms        retained                                                   id   
##    <chr>        <chr>                                                      <chr>
##  1 legal_status DPW Maintained                                             othe…
##  2 legal_status Permitted Site                                             othe…
##  3 legal_status Undocumented                                               othe…
##  4 species      Acacia melanoxylon :: Blackwood Acacia                     othe…
##  5 species      Arbutus 'Marina' :: Hybrid Strawberry Tree                 othe…
##  6 species      Callistemon citrinus :: Lemon Bottlebrush                  othe…
##  7 species      Corymbia ficifolia :: Red Flowering Gum                    othe…
##  8 species      Cupressus macrocarpa :: Monterey Cypress                   othe…
##  9 species      Eriobotrya deflexa :: Bronze Loquat                        othe…
## 10 species      Ficus microcarpa nitida 'Green Gem' :: Indian Laurel Fig … othe…
## # ℹ 25 more rows
# Coefficient table of the fitted linear model.
fitted_workflow_lm %>%
  extract_fit_parsnip() %>%
  tidy()
## # A tibble: 41 × 5
##    term                                   estimate std.error statistic   p.value
##    <chr>                                     <dbl>     <dbl>     <dbl>     <dbl>
##  1 (Intercept)                             6.35e+1   4.22e+0     15.1  3.57e- 51
##  2 site_order                             -3.08e-3   1.15e-4    -26.9  1.66e-158
##  3 date_planted                           -4.19e-5   4.83e-7    -86.8  0        
##  4 latitude                                5.56e-1   4.03e-2     13.8  2.75e- 43
##  5 longitude                               6.78e-1   3.28e-2     20.7  6.99e- 95
##  6 na_ind_date_planted                    -6.22e-1   9.62e-3    -64.7  0        
##  7 legal_status_Permitted.Site             1.07e-1   4.51e-3     23.8  4.47e-125
##  8 legal_status_Undocumented               7.55e-2   6.51e-3     11.6  4.43e- 31
##  9 legal_status_other                      6.61e-2   9.22e-3      7.16 7.88e- 13
## 10 species_Arbutus..Marina.....Hybrid.St… -3.67e-1   8.20e-3    -44.8  0        
## # ℹ 31 more rows
# Store the linear-model predictions (log10 scale) alongside the test data.
trees_testing[["pred_lm"]] <- fitted_workflow_lm %>%
  predict(new_data = trees_testing) %>%
  pull(.pred)

# Test-set RMSE of the linear model, measured on log10(diam).
rmse(trees_testing, truth = diam, estimate = pred_lm)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       0.316

7.3.1 Different model, same recipe

# Small random forest (50 trees, 3 candidate predictors per split, at least
# 10 observations required to split a node), fitted with the ranger engine.
rand_forest_spec <- rand_forest(mtry = 3, trees = 50, min_n = 10) %>%
  set_mode("regression") %>%
  set_engine("ranger")

# Reuse the linear-model workflow, swapping only the model; the recipe is kept.
tree_workflow_rf <- update_model(tree_workflow_lm, rand_forest_spec)

fitted_workflow_rf <- fit(tree_workflow_rf, data = trees_training)

# Store the random-forest predictions (log10 scale) alongside the test data.
trees_testing[["pred_rf"]] <- fitted_workflow_rf %>%
  predict(new_data = trees_testing) %>%
  pull(.pred)

# Linear-model test RMSE, repeated here for side-by-side comparison.
rmse(trees_testing, truth = diam, estimate = pred_lm)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       0.316
# Random-forest test RMSE on the same log10(diam) scale.
rmse(trees_testing, truth = diam, estimate = pred_rf)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       0.307

7.3.2 Same model, different preprocessing

# Keep the linear model but replace the recipe with a minimal formula
# preprocessor: a missing-date indicator plus longitude.
formula_predictions <- tree_workflow_lm %>%
  remove_recipe() %>%
  add_formula(diam ~ is.na(date_planted) + longitude) %>%
  fit(data = trees_training) %>%
  predict(new_data = trees_testing)

# Test-set RMSE computed directly from the truth and prediction vectors.
rmse_vec(
  truth = trees_testing[["diam"]],
  estimate = formula_predictions[[".pred"]]
)
## [1] 0.3574375