This walkthrough uses the tidymodels framework to create a decision tree
classifier model for the attrition dataset.
- Train/test split
attrition
# Create training (75%) and test (25%) sets for the
# rsample::attrition data. Use set.seed for reproducibility
set.seed(123)

# Split once, stratifying on the outcome column, then pull each partition
# out of the same split object so train and test never overlap.
churn_split <- initial_split(attrition, prop = 0.75, strata = "attrition")
train <- training(churn_split)
test <- testing(churn_split)
# k-folds
# 3-fold cross-validation resamples built from the training set only,
# stratified on the outcome column; used later for hyperparameter tuning.
folds <- vfold_cv(data = train, v = 3, strata = "attrition")
folds
## # 3-fold cross-validation using stratification
## # A tibble: 3 × 2
## splits id
## <list> <chr>
## 1 <split [734/367]> Fold1
## 2 <split [734/367]> Fold2
## 3 <split [734/367]> Fold3
- Create a recipe to preprocess the train and test datasets
# Preprocessing recipe: model `attrition` on every other column of `train`.
# NOTE(review): in current tidymodels releases step_upsample() is provided by
# the themis package -- confirm it is attached before this chunk runs.
recipe_obj <- recipe(attrition ~ ., data = train) %>%
  # remove all predictors with constant (zero) variance
  step_zv(all_predictors()) %>%
  # upsample the target to balance the class
  step_upsample(attrition)
# verify recipe with train dataset
# prep() estimates the recipe steps; bake(new_data = NULL) returns the
# processed training data, so the upsampled rows are visible here.
recipe_obj %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 1,848
## Columns: 31
## $ age <int> 33, 59, 30, 36, 35, 29, 31, 29, 53, 38, 24,…
## $ business_travel <fct> Travel_Frequently, Travel_Rarely, Travel_Ra…
## $ daily_rate <int> 1392, 1324, 1358, 1299, 809, 153, 670, 1389…
## $ department <fct> Research_Development, Research_Development,…
## $ distance_from_home <int> 3, 3, 24, 27, 16, 15, 26, 21, 2, 2, 11, 7, …
## $ education <ord> Master, Bachelor, Below_College, Bachelor, …
## $ education_field <fct> Life_Sciences, Medical, Life_Sciences, Medi…
## $ environment_satisfaction <ord> Very_High, High, Very_High, High, Low, Very…
## $ gender <fct> Female, Female, Male, Male, Male, Female, M…
## $ hourly_rate <int> 56, 81, 67, 94, 84, 49, 31, 51, 78, 45, 96,…
## $ job_involvement <ord> High, Very_High, High, High, Very_High, Med…
## $ job_level <int> 1, 1, 1, 2, 1, 2, 1, 3, 4, 1, 2, 3, 1, 3, 5…
## $ job_role <fct> Research_Scientist, Laboratory_Technician, …
## $ job_satisfaction <ord> High, Low, High, High, Medium, High, High, …
## $ marital_status <fct> Married, Married, Divorced, Married, Marrie…
## $ monthly_income <int> 2909, 2670, 2693, 5237, 2426, 4193, 2911, 9…
## $ monthly_rate <int> 23159, 9964, 13335, 16577, 16479, 12682, 15…
## $ num_companies_worked <int> 1, 4, 1, 6, 0, 0, 1, 1, 2, 5, 0, 0, 1, 3, 3…
## $ over_time <fct> Yes, Yes, No, No, No, Yes, No, No, No, Yes,…
## $ percent_salary_hike <int> 11, 20, 22, 13, 13, 12, 17, 11, 16, 11, 18,…
## $ performance_rating <ord> Excellent, Outstanding, Outstanding, Excell…
## $ relationship_satisfaction <ord> High, Low, Medium, Medium, High, Very_High,…
## $ stock_option_level <int> 0, 3, 1, 2, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0…
## $ total_working_years <int> 8, 12, 1, 17, 6, 10, 5, 10, 31, 6, 5, 13, 0…
## $ training_times_last_year <int> 3, 3, 2, 3, 5, 3, 1, 1, 3, 3, 5, 4, 6, 4, 2…
## $ work_life_balance <ord> Better, Good, Better, Good, Better, Better,…
## $ years_at_company <int> 8, 1, 1, 7, 5, 9, 5, 10, 25, 3, 4, 12, 0, 2…
## $ years_in_current_role <int> 7, 0, 0, 7, 4, 5, 2, 9, 8, 2, 2, 6, 0, 6, 2…
## $ years_since_last_promotion <int> 3, 0, 0, 7, 0, 0, 4, 8, 3, 1, 1, 2, 0, 5, 2…
## $ years_with_curr_manager <int> 0, 0, 0, 7, 3, 8, 3, 8, 7, 2, 3, 11, 0, 17,…
## $ attrition <fct> No, No, No, No, No, No, No, No, No, No, No,…
# verify recipe with test dataset
# Apply the prepped recipe to the held-out data and inspect the result.
recipe_obj %>%
  prep() %>%
  bake(new_data = test) %>%
  glimpse()
## Rows: 369
## Columns: 31
## $ age <int> 49, 27, 32, 38, 34, 32, 22, 53, 42, 35, 33,…
## $ business_travel <fct> Travel_Frequently, Travel_Rarely, Travel_Fr…
## $ daily_rate <int> 279, 591, 1005, 216, 1346, 334, 1123, 1282,…
## $ department <fct> Research_Development, Research_Development,…
## $ distance_from_home <int> 8, 2, 2, 23, 19, 5, 16, 5, 8, 4, 1, 1, 7, 2…
## $ education <ord> Below_College, Below_College, College, Bach…
## $ education_field <fct> Life_Sciences, Medical, Life_Sciences, Life…
## $ environment_satisfaction <ord> High, Low, Very_High, Very_High, Medium, Lo…
## $ gender <fct> Male, Male, Male, Male, Male, Male, Male, F…
## $ hourly_rate <int> 61, 40, 79, 44, 93, 80, 96, 58, 48, 75, 98,…
## $ job_involvement <ord> Medium, High, High, Medium, High, Very_High…
## $ job_level <int> 2, 1, 1, 3, 1, 1, 1, 5, 2, 1, 3, 2, 5, 3, 4…
## $ job_role <fct> Research_Scientist, Laboratory_Technician, …
## $ job_satisfaction <ord> Medium, Medium, Very_High, High, Very_High,…
## $ marital_status <fct> Married, Married, Single, Single, Divorced,…
## $ monthly_income <int> 5130, 3468, 3068, 9526, 2661, 3298, 2935, 1…
## $ monthly_rate <int> 24907, 16632, 11864, 8787, 8758, 15053, 732…
## $ num_companies_worked <int> 1, 9, 0, 0, 0, 0, 1, 4, 0, 1, 1, 1, 5, 7, 2…
## $ over_time <fct> No, No, No, No, No, Yes, Yes, No, No, No, Y…
## $ percent_salary_hike <int> 23, 12, 13, 21, 11, 12, 13, 11, 11, 12, 12,…
## $ performance_rating <ord> Outstanding, Excellent, Excellent, Outstand…
## $ relationship_satisfaction <ord> Very_High, Very_High, High, Medium, High, V…
## $ stock_option_level <int> 1, 1, 0, 0, 1, 2, 2, 1, 1, 1, 0, 1, 1, 0, 3…
## $ total_working_years <int> 10, 6, 8, 10, 3, 7, 1, 26, 10, 1, 15, 7, 29…
## $ training_times_last_year <int> 3, 3, 2, 2, 2, 5, 2, 3, 2, 3, 1, 2, 2, 3, 2…
## $ work_life_balance <ord> Better, Better, Good, Better, Better, Good,…
## $ years_at_company <int> 10, 2, 7, 9, 2, 6, 1, 14, 9, 1, 15, 7, 27, …
## $ years_in_current_role <int> 7, 2, 7, 7, 2, 2, 0, 13, 7, 0, 14, 5, 3, 16…
## $ years_since_last_promotion <int> 1, 2, 3, 1, 1, 0, 0, 4, 4, 0, 8, 0, 13, 7, …
## $ years_with_curr_manager <int> 7, 2, 6, 8, 2, 5, 0, 8, 2, 0, 12, 7, 8, 9, …
## $ attrition <fct> No, No, No, No, No, No, No, No, No, No, No,…
- Create a model specification
# decision tree spec (tuning hyperparameters)
# All three main arguments are marked with tune() so their values are
# chosen during tuning rather than fixed here.
tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart") %>%
  set_mode("classification")
tree_spec
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = tune()
## tree_depth = tune()
## min_n = tune()
##
## Computational engine: rpart
- Create a workflow object to combine the model spec and recipe objects
# workflow spec
# Bundle the model specification and the preprocessing recipe into a single
# workflow object.
wflw_tree <- workflow() %>%
  add_model(tree_spec) %>%
  add_recipe(recipe_obj)
- Create a hyperparameter tuning grid
# tuning grid
# Regular grid with 5 levels per parameter (5 x 5 x 5 = 125 candidates).
# cost_complexity() uses its default range; tree_depth and min_n are
# restricted to the ranges given below.
tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(range = c(5, 9)),
  min_n(range = c(11, 15)),
  levels = 5
)
tree_grid
## # A tibble: 125 × 3
## cost_complexity tree_depth min_n
## <dbl> <int> <int>
## 1 0.0000000001 5 11
## 2 0.0000000178 5 11
## 3 0.00000316 5 11
## 4 0.000562 5 11
## 5 0.1 5 11
## 6 0.0000000001 6 11
## 7 0.0000000178 6 11
## 8 0.00000316 6 11
## 9 0.000562 6 11
## 10 0.1 6 11
## # ℹ 115 more rows