5.3 Class imbalance
In many instances, random splitting is not suitable. One example is a dataset with class imbalance, where one class greatly outnumbers another. Class imbalance is important to detect and take into account when splitting data. Randomly splitting a dataset with severe class imbalance can, by chance, put a disproportionate share of the minority class into either the training set or the test set. This can make the model perform poorly at validation. The goal is for the class distribution to be the same across the training and test sets. Class imbalance can occur in differing degrees:
Splitting methods suited to datasets with class imbalance should therefore be considered. Take, for example, a #TidyTuesday dataset on Himalayan expedition members, which Julia Silge recently explored using {tidymodels}.
library(tidyverse)
library(skimr)
members <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-09-22/members.csv")
skim(members)
Name | members |
---|---|
Number of rows | 76519 |
Number of columns | 21 |
Column type frequency: | |
character | 10 |
logical | 6 |
numeric | 5 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
expedition_id | 0 | 1.00 | 9 | 9 | 0 | 10350 | 0 |
member_id | 0 | 1.00 | 12 | 12 | 0 | 76518 | 0 |
peak_id | 0 | 1.00 | 4 | 4 | 0 | 391 | 0 |
peak_name | 15 | 1.00 | 4 | 25 | 0 | 390 | 0 |
season | 0 | 1.00 | 6 | 7 | 0 | 5 | 0 |
sex | 2 | 1.00 | 1 | 1 | 0 | 2 | 0 |
citizenship | 10 | 1.00 | 2 | 23 | 0 | 212 | 0 |
expedition_role | 21 | 1.00 | 4 | 25 | 0 | 524 | 0 |
death_cause | 75413 | 0.01 | 3 | 27 | 0 | 12 | 0 |
injury_type | 74807 | 0.02 | 3 | 27 | 0 | 11 | 0 |
Variable type: logical
skim_variable | n_missing | complete_rate | mean | count |
---|---|---|---|---|
hired | 0 | 1 | 0.21 | FAL: 60788, TRU: 15731 |
success | 0 | 1 | 0.38 | FAL: 47320, TRU: 29199 |
solo | 0 | 1 | 0.00 | FAL: 76398, TRU: 121 |
oxygen_used | 0 | 1 | 0.24 | FAL: 58286, TRU: 18233 |
died | 0 | 1 | 0.01 | FAL: 75413, TRU: 1106 |
injured | 0 | 1 | 0.02 | FAL: 74806, TRU: 1713 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
year | 0 | 1.00 | 2000.36 | 14.78 | 1905 | 1991 | 2004 | 2012 | 2019 | ▁▁▁▃▇ |
age | 3497 | 0.95 | 37.33 | 10.40 | 7 | 29 | 36 | 44 | 85 | ▁▇▅▁▁ |
highpoint_metres | 21833 | 0.71 | 7470.68 | 1040.06 | 3800 | 6700 | 7400 | 8400 | 8850 | ▁▁▆▃▇ |
death_height_metres | 75451 | 0.01 | 6592.85 | 1308.19 | 400 | 5800 | 6600 | 7550 | 8830 | ▁▁▂▇▆ |
injury_height_metres | 75510 | 0.01 | 7049.91 | 1214.24 | 400 | 6200 | 7100 | 8000 | 8880 | ▁▁▂▇▇ |
Let’s say we were interested in predicting the likelihood of survival or death for an expedition member. It would be a good idea to check for class imbalance:
library(janitor)
members %>%
  tabyl(died) %>%
  adorn_totals("row")
## died n percent
## FALSE 75413 0.98554607
## TRUE 1106 0.01445393
## Total 76519 1.00000000
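If janitor is not available, a roughly equivalent table (without the totals row) can be built with dplyr's count(). The sketch below uses a hypothetical toy tibble with the same died counts as the table above so that it runs offline; with the real data you would pipe members in directly.

```r
library(dplyr)

# Toy stand-in with the same counts as members$died (75413 FALSE, 1106 TRUE)
toy_members <- tibble(died = c(rep(FALSE, 75413), rep(TRUE, 1106)))

# Count each class and convert the counts to proportions
toy_members %>%
  count(died) %>%
  mutate(percent = n / sum(n))
```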
We can see that nearly 99% of expedition members survive, so only about 1.4% of the rows belong to the minority class. With imbalance this severe, the split should preserve the class proportions in both sets. Stratified sampling does this: “the training/test split is conducted separately within each class and then these subsamples are combined into the overall training and test set”. Operationally, this is done by using the strata argument inside initial_split():
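library(rsample) # provides initial_split(), training() and testing()
set.seed(123)
# 80/20 split, stratified on the outcome
members_split <- initial_split(members, prop = 0.80, strata = died)
members_train <- training(members_split)
members_test <- testing(members_split)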
5.3.1 Stratified sampling simulation
A simulation shows the effect of stratification. We expect the mean proportion of the minority class in the test set to be unchanged by stratification, but its variance across resamples to be lower.
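A short calculation backs up this expectation. It is a sketch that assumes simple random sampling without replacement and ignores rsample's rounding details. Let $N$ be the number of rows, $p$ the share of the minority class, and $m$ the size of the test set. Without stratification, the number of minority rows in the test set is hypergeometric, so the test-set proportion $\hat{p}$ satisfies

$$
\mathbb{E}[\hat{p}] = p, \qquad
\operatorname{Var}(\hat{p}) = \frac{p(1-p)}{m}\cdot\frac{N-m}{N-1}.
$$

With stratification, the split is made within each class, so the test set always receives about $pm$ minority rows. This gives $\mathbb{E}[\hat{p}] \approx p$ and $\operatorname{Var}(\hat{p}) \approx 0$. Consider the 9% case below, with $N = 1000$ and initial_split()'s default prop = 3/4, so $m = 250$. For that case the unstratified formula gives $0.09 \times 0.91 / 250 \times 750 / 999 \approx 2.5 \times 10^{-4}$.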
simulate_stratified_sampling <- function(prop_in_dataset, n_resample = 50,
                                         n_rows = 1000, seed = 45678) {
  set.seed(seed)
  # Toy outcome: a fraction prop_in_dataset of rows are 1 (died), the rest 0
  data_to_split <- tibble(died = c(
    rep(1, floor(n_rows * prop_in_dataset)),
    rep(0, floor(n_rows * (1 - prop_in_dataset)))
  ))
  # Proportion of deaths in the test set over repeated unstratified splits
  samples <- map_dfr(seq_len(n_resample), ~{
    initial_split(data_to_split) %>%
      testing() %>%
      summarize(died_pct = mean(died))
  })
  # The same, but stratifying on died
  stratified_samples <- map_dfr(seq_len(n_resample), ~{
    initial_split(data_to_split, strata = died) %>%
      testing() %>%
      summarize(died_pct = mean(died))
  })
  # Mean and variance of the test-set death proportion for each method
  rbind(
    samples %>% mutate(stratified = FALSE),
    stratified_samples %>% mutate(stratified = TRUE)
  ) %>%
    group_by(stratified) %>%
    summarize(mean = mean(died_pct), var = var(died_pct))
}
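With a 9% minority class, rsample does not actually stratify. By default it pools any stratum that makes up less than 10% of the data (controlled by the pool argument), so both rows below are effectively unstratified: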
simulate_stratified_sampling(0.09)
## # A tibble: 2 × 3
## stratified mean var
## <lgl> <dbl> <dbl>
## 1 FALSE 0.0910 0.000240
## 2 TRUE 0.0917 0.000257
With an 11% minority class, stratified sampling takes effect, and the test-set proportion no longer varies between resamples:
simulate_stratified_sampling(0.11)
## # A tibble: 2 × 3
## stratified mean var
## <lgl> <dbl> <dbl>
## 1 FALSE 0.110 0.000257
## 2 TRUE 0.112 0
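The 10% threshold above corresponds to rsample's pool argument to initial_split(), whose default is 0.1. The sketch below reuses the toy setup from the simulation with a 9% minority class. It lowers pool to 0.05 (an arbitrary value chosen for illustration) so that the minority class is no longer pooled and the split is stratified. rsample may emit a warning here, because its documentation advises against decreasing pool below the default when strata are this small.

```r
library(rsample)
library(dplyr)

set.seed(45678)
# Same toy setup as the simulation: 90 deaths (1) and 910 survivors (0)
data_to_split <- tibble(died = c(rep(1, 90), rep(0, 910)))

# pool = 0.05 is below the minority share (0.09), so died is stratified
pool_split <- initial_split(data_to_split, strata = died, pool = 0.05)

# Death proportion in the test set; stays close to 0.09 across seeds
mean(testing(pool_split)$died)
```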