Naive Bayes in R
To implement naive Bayes classification in R, we’ll use the naiveBayes() function in the e1071 package (Meyer et al. 2021).
Predictions
##        Adelie Chinstrap Gentoo
## [1,] 0.000169  0.397831  0.602
## [1] Gentoo
## Levels: Adelie Chinstrap Gentoo
##        Adelie Chinstrap   Gentoo
## [1,] 0.000345  0.994868 0.004787
## [1] Chinstrap
## Levels: Adelie Chinstrap Gentoo
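The outputs above have the form returned by predict() on a naiveBayes fit: a row of posterior probabilities (type = "raw") followed by the most probable class. The calls that produced them are not shown in these notes, so the following is only a sketch of how such output could be generated. It assumes the data come from penguins_bayes in the bayesrules package, that naive_model_1 uses bill length alone and naive_model_2 adds flipper length (inferred from the columns used later), and it invents a hypothetical new_penguin to classify.

```r
library(e1071)       # provides naiveBayes()
library(bayesrules)  # assumed data source: penguins_bayes

data(penguins_bayes)
penguins <- penguins_bayes

# Assumed formulas: model 1 uses bill length only, model 2 adds flipper length
naive_model_1 <- naiveBayes(species ~ bill_length_mm, data = penguins)
naive_model_2 <- naiveBayes(species ~ bill_length_mm + flipper_length_mm,
                            data = penguins)

# Hypothetical penguin to classify; these measurements are illustrative only
new_penguin <- data.frame(bill_length_mm = 50, flipper_length_mm = 195)

# Posterior probability of each species under each model
predict(naive_model_1, newdata = new_penguin, type = "raw")
predict(naive_model_2, newdata = new_penguin, type = "raw")

# Most probable species (the default prediction type, "class")
predict(naive_model_1, newdata = new_penguin)
predict(naive_model_2, newdata = new_penguin)
```

With type = "raw" each row sums to 1 across the three species, and the default call simply returns the species with the largest of those probabilities.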
Confusion Matrices
# The magrittr pipe is needed here: the `.` placeholder is not supported by |>
penguins <- penguins %>%
  mutate(class_1 = predict(naive_model_1, newdata = .),
         class_2 = predict(naive_model_2, newdata = .))
set.seed(84735)
penguins |>
  sample_n(4) |>
  select(bill_length_mm, flipper_length_mm, species, class_1, class_2) |>
  rename(bill = bill_length_mm, flipper = flipper_length_mm)
## # A tibble: 4 × 5
##    bill flipper species   class_1 class_2
##   <dbl>   <int> <fct>     <fct>   <fct>
## 1  47.5     199 Chinstrap Gentoo  Chinstrap
## 2  40.9     214 Gentoo    Adelie  Gentoo
## 3  41.3     194 Adelie    Adelie  Adelie
## 4  38.5     190 Adelie    Adelie  Adelie
# Confusion matrix for naive_model_1
penguins |>
  tabyl(species, class_1) |>
  adorn_percentages("row") |>
  adorn_pct_formatting(digits = 2) |>
  adorn_ns()
##    species       Adelie   Chinstrap       Gentoo
##     Adelie 95.39% (145)  0.00%  (0)  4.61%   (7)
##  Chinstrap  5.88%   (4)  8.82%  (6) 85.29%  (58)
##     Gentoo  6.45%   (8)  4.84%  (6) 88.71% (110)
- Overall accuracy: 76 percent (261 of 344 penguins correctly classified)
- 85 percent of Chinstrap penguins are misclassified as Gentoo!
# Confusion matrix for naive_model_2
penguins |>
  tabyl(species, class_2) |>
  adorn_percentages("row") |>
  adorn_pct_formatting(digits = 2) |>
  adorn_ns()
##    species       Adelie   Chinstrap       Gentoo
##     Adelie 96.05% (146)  2.63%  (4)  1.32%   (2)
##  Chinstrap  7.35%   (5) 86.76% (59)  5.88%   (4)
##     Gentoo  0.81%   (1)  0.81%  (1) 98.39% (122)
- Overall accuracy: 95 percent (327 of 344 penguins correctly classified)
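The accuracy figures quoted for the two models follow directly from the counts in their confusion matrices: correct classifications sit on the diagonal, and each row sums to the number of penguins of that species. The sketch below re-enters those counts by hand in base R; the object and function names (conf_1, conf_2, overall_accuracy, class_accuracy) are introduced here for illustration only.

```r
# Counts copied from the two confusion matrices (rows = actual, columns = predicted)
species <- c("Adelie", "Chinstrap", "Gentoo")

conf_1 <- matrix(c(145, 0,   7,
                     4, 6,  58,
                     8, 6, 110),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(actual = species, predicted = species))

conf_2 <- matrix(c(146,  4,   2,
                     5, 59,   4,
                     1,  1, 122),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(actual = species, predicted = species))

# Share of all penguins that were classified correctly
overall_accuracy <- function(conf) sum(diag(conf)) / sum(conf)

# Share of each species that was classified correctly (row-wise accuracy)
class_accuracy <- function(conf) diag(conf) / rowSums(conf)

round(100 * overall_accuracy(conf_1))    # 76
round(100 * overall_accuracy(conf_2))    # 95
round(100 * class_accuracy(conf_1), 2)   # Adelie 95.39, Chinstrap 8.82, Gentoo 88.71
```

The row-wise view is what exposes the weakness of naive_model_1: its overall accuracy looks moderate, yet fewer than 9 percent of Chinstrap penguins are classified correctly.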
Cross-Validation
# 10-fold cross-validation
set.seed(84735)
cv_model_2 <- naive_classification_summary_cv(
  model = naive_model_2, data = penguins, y = "species", k = 10)
cv_model_2$cv
##    species       Adelie   Chinstrap       Gentoo
##     Adelie 96.05% (146)  2.63%  (4)  1.32%   (2)
##  Chinstrap  7.35%   (5) 86.76% (59)  5.88%   (4)
##     Gentoo  0.81%   (1)  0.81%  (1) 98.39% (122)
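To make the cross-validation step less of a black box, here is a rough hand-rolled version of 10-fold cross-validation for the two-predictor model. It is a sketch of the general technique, not the implementation used by naive_classification_summary_cv(); it assumes the data come from penguins_bayes in the bayesrules package and that the model formula is species ~ bill_length_mm + flipper_length_mm. The names fold_id, cv_pred, and cv_accuracy are introduced here for illustration. Because the fold assignment differs, the counts need not match the table above exactly.

```r
library(e1071)       # provides naiveBayes()
library(bayesrules)  # assumed data source: penguins_bayes

data(penguins_bayes)
penguins <- penguins_bayes

set.seed(84735)
k <- 10

# Randomly assign each penguin to one of k roughly equal-sized folds
fold_id <- sample(rep(seq_len(k), length.out = nrow(penguins)))

# Placeholder for the out-of-sample prediction of every penguin
cv_pred <- factor(rep(NA, nrow(penguins)), levels = levels(penguins$species))

for (i in seq_len(k)) {
  held_out <- fold_id == i

  # Train on the other k - 1 folds (assumed formula for naive_model_2)
  fit <- naiveBayes(species ~ bill_length_mm + flipper_length_mm,
                    data = penguins[!held_out, ])

  # Classify only the penguins the model has not seen
  cv_pred[held_out] <- predict(fit, newdata = penguins[held_out, ])
}

# Cross-validated confusion matrix (counts) and overall accuracy
table(actual = penguins$species, predicted = cv_pred)
cv_accuracy <- mean(cv_pred == penguins$species)
cv_accuracy
```

Every penguin is predicted exactly once, by a model trained without it, so cv_accuracy estimates how the classifier would perform on new penguins rather than on the data used to fit it.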