Using the tidymodels framework to build a decision tree classifier for the employee attrition dataset.

  1. Split the attrition data into training and test sets
# Split the attrition data into a training set (75%) and a test set (25%).
# Stratifying on the outcome keeps the attrition rate similar in both sets.
# NOTE(review): the snake_case column names (`attrition`, `business_travel`,
# ...) indicate this is a cleaned copy of the attrition data defined earlier,
# not the raw rsample/modeldata object -- confirm its source.
# A fixed seed makes the random split reproducible.
set.seed(123)
churn_split <- initial_split(
  data = attrition,
  prop = 0.75,
  strata = "attrition"
)
train <- training(churn_split)
test <- testing(churn_split)

# 3-fold cross-validation resamples of the training set, stratified on the
# outcome so each fold has a similar attrition rate.
folds <- vfold_cv(train, v = 3, strata = "attrition")
folds
## #  3-fold cross-validation using stratification 
## # A tibble: 3 × 2
##   splits            id   
##   <list>            <chr>
## 1 <split [734/367]> Fold1
## 2 <split [734/367]> Fold2
## 3 <split [734/367]> Fold3
  1. Create a recipe to preprocess the training and test sets
# Preprocessing recipe: predict `attrition` from every other column of `train`.
recipe_obj <- recipe(attrition ~ ., data = train) %>%
  # Drop predictors with zero variance (a single unique value); they carry no
  # information for the model.
  step_zv(all_predictors()) %>%
  # Upsample the minority class of `attrition` so both classes have the same
  # count. step_upsample() is provided by the themis package (recipes no
  # longer exports it), so it is namespace-qualified rather than relying on
  # themis being attached. With its default skip = TRUE, the upsampling is
  # applied to the training data when the recipe is prepped, but skipped when
  # baking new data such as the test set.
  themis::step_upsample(attrition)

# Sanity-check the recipe on the training data. prep() estimates the steps
# from `train`, and bake(new_data = NULL) returns the processed training set,
# including the upsampled rows.
glimpse(bake(prep(recipe_obj), new_data = NULL))
## Rows: 1,848
## Columns: 31
## $ age                        <int> 33, 59, 30, 36, 35, 29, 31, 29, 53, 38, 24,…
## $ business_travel            <fct> Travel_Frequently, Travel_Rarely, Travel_Ra…
## $ daily_rate                 <int> 1392, 1324, 1358, 1299, 809, 153, 670, 1389…
## $ department                 <fct> Research_Development, Research_Development,…
## $ distance_from_home         <int> 3, 3, 24, 27, 16, 15, 26, 21, 2, 2, 11, 7, …
## $ education                  <ord> Master, Bachelor, Below_College, Bachelor, …
## $ education_field            <fct> Life_Sciences, Medical, Life_Sciences, Medi…
## $ environment_satisfaction   <ord> Very_High, High, Very_High, High, Low, Very…
## $ gender                     <fct> Female, Female, Male, Male, Male, Female, M…
## $ hourly_rate                <int> 56, 81, 67, 94, 84, 49, 31, 51, 78, 45, 96,…
## $ job_involvement            <ord> High, Very_High, High, High, Very_High, Med…
## $ job_level                  <int> 1, 1, 1, 2, 1, 2, 1, 3, 4, 1, 2, 3, 1, 3, 5…
## $ job_role                   <fct> Research_Scientist, Laboratory_Technician, …
## $ job_satisfaction           <ord> High, Low, High, High, Medium, High, High, …
## $ marital_status             <fct> Married, Married, Divorced, Married, Marrie…
## $ monthly_income             <int> 2909, 2670, 2693, 5237, 2426, 4193, 2911, 9…
## $ monthly_rate               <int> 23159, 9964, 13335, 16577, 16479, 12682, 15…
## $ num_companies_worked       <int> 1, 4, 1, 6, 0, 0, 1, 1, 2, 5, 0, 0, 1, 3, 3…
## $ over_time                  <fct> Yes, Yes, No, No, No, Yes, No, No, No, Yes,…
## $ percent_salary_hike        <int> 11, 20, 22, 13, 13, 12, 17, 11, 16, 11, 18,…
## $ performance_rating         <ord> Excellent, Outstanding, Outstanding, Excell…
## $ relationship_satisfaction  <ord> High, Low, Medium, Medium, High, Very_High,…
## $ stock_option_level         <int> 0, 3, 1, 2, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0…
## $ total_working_years        <int> 8, 12, 1, 17, 6, 10, 5, 10, 31, 6, 5, 13, 0…
## $ training_times_last_year   <int> 3, 3, 2, 3, 5, 3, 1, 1, 3, 3, 5, 4, 6, 4, 2…
## $ work_life_balance          <ord> Better, Good, Better, Good, Better, Better,…
## $ years_at_company           <int> 8, 1, 1, 7, 5, 9, 5, 10, 25, 3, 4, 12, 0, 2…
## $ years_in_current_role      <int> 7, 0, 0, 7, 4, 5, 2, 9, 8, 2, 2, 6, 0, 6, 2…
## $ years_since_last_promotion <int> 3, 0, 0, 7, 0, 0, 4, 8, 3, 1, 1, 2, 0, 5, 2…
## $ years_with_curr_manager    <int> 0, 0, 0, 7, 3, 8, 3, 8, 7, 2, 3, 11, 0, 17,…
## $ attrition                  <fct> No, No, No, No, No, No, No, No, No, No, No,…
# Sanity-check the recipe on the held-out test data. The trained recipe is
# applied to `test`; the upsampling step does not add rows to new data, so the
# row count matches the test set.
glimpse(bake(prep(recipe_obj), new_data = test))
## Rows: 369
## Columns: 31
## $ age                        <int> 49, 27, 32, 38, 34, 32, 22, 53, 42, 35, 33,…
## $ business_travel            <fct> Travel_Frequently, Travel_Rarely, Travel_Fr…
## $ daily_rate                 <int> 279, 591, 1005, 216, 1346, 334, 1123, 1282,…
## $ department                 <fct> Research_Development, Research_Development,…
## $ distance_from_home         <int> 8, 2, 2, 23, 19, 5, 16, 5, 8, 4, 1, 1, 7, 2…
## $ education                  <ord> Below_College, Below_College, College, Bach…
## $ education_field            <fct> Life_Sciences, Medical, Life_Sciences, Life…
## $ environment_satisfaction   <ord> High, Low, Very_High, Very_High, Medium, Lo…
## $ gender                     <fct> Male, Male, Male, Male, Male, Male, Male, F…
## $ hourly_rate                <int> 61, 40, 79, 44, 93, 80, 96, 58, 48, 75, 98,…
## $ job_involvement            <ord> Medium, High, High, Medium, High, Very_High…
## $ job_level                  <int> 2, 1, 1, 3, 1, 1, 1, 5, 2, 1, 3, 2, 5, 3, 4…
## $ job_role                   <fct> Research_Scientist, Laboratory_Technician, …
## $ job_satisfaction           <ord> Medium, Medium, Very_High, High, Very_High,…
## $ marital_status             <fct> Married, Married, Single, Single, Divorced,…
## $ monthly_income             <int> 5130, 3468, 3068, 9526, 2661, 3298, 2935, 1…
## $ monthly_rate               <int> 24907, 16632, 11864, 8787, 8758, 15053, 732…
## $ num_companies_worked       <int> 1, 9, 0, 0, 0, 0, 1, 4, 0, 1, 1, 1, 5, 7, 2…
## $ over_time                  <fct> No, No, No, No, No, Yes, Yes, No, No, No, Y…
## $ percent_salary_hike        <int> 23, 12, 13, 21, 11, 12, 13, 11, 11, 12, 12,…
## $ performance_rating         <ord> Outstanding, Excellent, Excellent, Outstand…
## $ relationship_satisfaction  <ord> Very_High, Very_High, High, Medium, High, V…
## $ stock_option_level         <int> 1, 1, 0, 0, 1, 2, 2, 1, 1, 1, 0, 1, 1, 0, 3…
## $ total_working_years        <int> 10, 6, 8, 10, 3, 7, 1, 26, 10, 1, 15, 7, 29…
## $ training_times_last_year   <int> 3, 3, 2, 2, 2, 5, 2, 3, 2, 3, 1, 2, 2, 3, 2…
## $ work_life_balance          <ord> Better, Better, Good, Better, Better, Good,…
## $ years_at_company           <int> 10, 2, 7, 9, 2, 6, 1, 14, 9, 1, 15, 7, 27, …
## $ years_in_current_role      <int> 7, 2, 7, 7, 2, 2, 0, 13, 7, 0, 14, 5, 3, 16…
## $ years_since_last_promotion <int> 1, 2, 3, 1, 1, 0, 0, 4, 4, 0, 8, 0, 13, 7, …
## $ years_with_curr_manager    <int> 7, 2, 6, 8, 2, 5, 0, 8, 2, 0, 12, 7, 8, 9, …
## $ attrition                  <fct> No, No, No, No, No, No, No, No, No, No, No,…
  1. Create a decision tree model specification
# Decision tree specification for classification, fitted with the rpart
# engine. cost_complexity, tree_depth and min_n are tune() placeholders whose
# values come from the tuning grid.
tree_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# Print the specification.
tree_spec
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
##   min_n = tune()
## 
## Computational engine: rpart
  1. Create a workflow that combines the model specification and the recipe
# Workflow bundling the tunable tree specification with the preprocessing
# recipe, so both run together whenever the workflow is fitted or tuned.
wflw_tree <- workflow() %>%
  add_model(tree_spec) %>%
  add_recipe(recipe_obj)
  1. Create a hyperparameter tuning grid
# Regular tuning grid with 5 levels per hyperparameter:
#   - cost_complexity over its dials default range (1e-10 to 0.1, log10 scale)
#   - tree_depth from 5 to 9
#   - min_n from 11 to 15
# giving 5^3 = 125 candidate combinations.
tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(range = c(5, 9)),
  min_n(range = c(11, 15)),
  levels = 5
)

tree_grid
## # A tibble: 125 × 3
##    cost_complexity tree_depth min_n
##              <dbl>      <int> <int>
##  1    0.0000000001          5    11
##  2    0.0000000178          5    11
##  3    0.00000316            5    11
##  4    0.000562              5    11
##  5    0.1                   5    11
##  6    0.0000000001          6    11
##  7    0.0000000178          6    11
##  8    0.00000316            6    11
##  9    0.000562              6    11
## 10    0.1                   6    11
## # ℹ 115 more rows