This report aims to present the capabilities of the package flashlight.

The document is a part of the paper “Landscape of R packages for eXplainable Machine Learning”, S. Maksymiuk, A. Gosiewska, and P. Biecek. (https://arxiv.org/abs/2009.13248). It contains a real-life use case built with the help of the titanic_imputed data set described in Section Example gallery for XAI packages of the article.

We did our best to show the entire range of the implemented explanations. Please note that the examples may be incomplete. If you think something is missing, feel free to make a pull request at the GitHub repository MI2DataLab/XAI-tools.

The list of use-cases for all packages included in the article is here.

Load titanic_imputed data set.

# Load the titanic_imputed data set from the DALEX package and preview its
# first six rows.
data("titanic_imputed", package = "DALEX")

head(titanic_imputed, n = 6L)
##   gender age class    embarked  fare sibsp parch survived
## 1   male  42   3rd Southampton  7.11     0     0        0
## 2   male  13   3rd Southampton 20.05     0     2        0
## 3   male  16   3rd Southampton 20.05     1     1        0
## 4 female  39   3rd Southampton 20.05     1     1        1
## 5 female  16   3rd Southampton  7.13     0     0        1
## 6   male  25   3rd Southampton  7.13     0     0        1
library(flashlight)
library(MetricsWeighted)

Fit a forest type model and decision tree to the titanic imputed data.

# Fit the two models that the rest of this report explains.
# `seed` pins the random forest so that the fitted model, and every
# explanation derived from it, is reproducible between runs; the explanation
# calls further down already fix their own seeds.
ranger_model <- ranger::ranger(
  survived ~ .,
  data = titanic_imputed,
  classification = TRUE,
  probability = TRUE,
  seed = 123
)
# Classification tree on the same predictors; the response is converted to a
# factor so that rpart grows a classification tree.
tree_model <- rpart::rpart(as.factor(survived) ~ ., data = titanic_imputed)

Model Parts

Permutation feature importance - One model

# Prediction wrapper handed to flashlight. It returns the second column of the
# ranger probability matrix, as the later sections of this report do, so the
# score points in the same direction as the response `survived`.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)
# AUC is a "higher is better" metric, whereas light_importance() assumes
# "lower is better" by default. Declaring it here reports the drop in AUC after
# permuting a feature as a positive importance.
imp <- light_importance(fl, m_repetitions = 10, lower_is_better = FALSE)
plot(imp, fill = "darkred")

Permutation feature importance - Two models

# Prediction wrappers for both models. Each returns the second probability
# column, as the later sections of this report do, so the scores point in the
# same direction as the response `survived`.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict_ranger <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  predict(X.model, new_data)[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(list(fl_ranger, fl_rpart), data = titanic_imputed, y = "survived")

# AUC is a "higher is better" metric, whereas light_importance() assumes
# "lower is better" by default. Declaring it here reports the drop in AUC after
# permuting a feature as a positive importance.
imp <- light_importance(fl, m_repetitions = 10, lower_is_better = FALSE)
plot(imp, fill = "darkred")

Interactions - One model

# Prediction wrapper handed to flashlight. It returns the second column of the
# ranger probability matrix, consistent with the profile sections of this
# report.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)
# Overall interaction strength per feature; the seed is fixed for
# reproducibility.
st_1 <- light_interaction(fl, seed = 123)
plot(st_1, fill = "darkred")

Interactions - Two models

# Prediction wrappers for both models. Each returns the second probability
# column, consistent with the profile sections of this report.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict_ranger <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  predict(X.model, new_data)[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(list(fl_ranger, fl_rpart), data = titanic_imputed, y = "survived")

# Overall interaction strength per feature and model; the seed is fixed for
# reproducibility.
st_2 <- light_interaction(fl, seed = 123)
plot(st_2, fill = "darkred")

Pairwise Interactions - One model

# Prediction wrapper handed to flashlight. It returns the second column of the
# ranger probability matrix, consistent with the profile sections of this
# report.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)
# The overall interaction strengths are computed first so that the pairwise
# analysis can be restricted to the four features selected by most_important().
st_1 <- light_interaction(fl, seed = 123)
stp <- light_interaction(
  fl,
  seed = 123,
  v = most_important(st_1, 4),
  pairwise = TRUE
)
plot(stp, fill = "darkred")

Pairwise Interactions - Two models

# Prediction wrappers for both models. Each returns the second probability
# column, consistent with the profile sections of this report.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict_ranger <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  predict(X.model, new_data)[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(list(fl_ranger, fl_rpart), data = titanic_imputed, y = "survived")
# The overall interaction strengths are computed first so that the pairwise
# analysis can be restricted to the four features selected by most_important().
st_2 <- light_interaction(fl, seed = 123)

stp <- light_interaction(
  fl,
  seed = 123,
  v = most_important(st_2, 4),
  pairwise = TRUE
)
plot(stp, fill = "darkred")

Model Profile

ALE Plot - One model

# Prediction wrapper handed to flashlight. It returns the second column of the
# ranger probability matrix, as the other profile sections of this report do,
# so the ALE curve describes the same class as the response `survived`.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Accumulated local effects profile for `fare`.
ale <- light_profile(fl, v = "fare", type = "ale")
plot(ale)

ALE Plot - Two models

# Prediction wrappers for both models. Each returns the second probability
# column, as the other profile sections of this report do, so the ALE curves
# describe the same class as the response `survived`.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict_ranger <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  predict(X.model, new_data)[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(list(fl_ranger, fl_rpart), data = titanic_imputed, y = "survived")

# Accumulated local effects profile for `fare`, one curve per model.
ale <- light_profile(fl, v = "fare", type = "ale")
plot(ale)

Partial Dependence Plot - One model

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Partial dependence profile for `fare`.
pdp <- light_profile(fl, v = "fare", type = "partial dependence")
plot(pdp)

Partial Dependence Plot - Two models

# Prediction wrappers for both models. Both must return the same probability
# column; otherwise the two curves in the shared plot would describe opposite
# classes. Column 2 is used, as in the other profile sections of this report.
# NOTE(review): assumes column 2 is the probability of survived == 1 - confirm.
custom_predict_ranger <- function(X.model, new_data) {
  predict(X.model, new_data)$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  predict(X.model, new_data)[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(list(fl_ranger, fl_rpart), data = titanic_imputed, y = "survived")

# Partial dependence profile for `fare`, one curve per model.
pdp <- light_profile(fl, v = "fare", type = "partial dependence")
plot(pdp)

Predicted value profile - One model

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Profile of type "predicted" for `fare`.
pvp <- light_profile(fl, v = "fare", type = "predicted")
plot(pvp)

Predicted value profile - Two models

# Prediction wrappers: each extracts the second probability column from the
# respective model's predict() output.
custom_predict_ranger <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  class_probs <- predict(X.model, new_data)
  class_probs[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(
  list(fl_ranger, fl_rpart),
  data = titanic_imputed,
  y = "survived"
)

# Profile of type "predicted" for `fare`, computed on the multiflashlight.
pvp <- light_profile(fl, v = "fare", type = "predicted")
plot(pvp)

Response profile

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Profile of type "response" for `fare`.
pvp <- light_profile(fl, v = "fare", type = "response")
plot(pvp)

Residual profile - One model

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Profile of type "residual" for `fare`.
rvp <- light_profile(fl, v = "fare", type = "residual")
plot(rvp)

Residual profile - Two models

# Prediction wrappers: each extracts the second probability column from the
# respective model's predict() output.
custom_predict_ranger <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  class_probs <- predict(X.model, new_data)
  class_probs[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(
  list(fl_ranger, fl_rpart),
  data = titanic_imputed,
  y = "survived"
)

# Profile of type "residual" for `fare`, computed on the multiflashlight.
rvp <- light_profile(fl, v = "fare", type = "residual")
plot(rvp)

Global surrogate model

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Global surrogate of the ranger flashlight, built with default settings.
surr <- light_global_surrogate(fl)
plot(surr)

Model diagnostics

Residual vs variables plots

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Residual profile for `fare` with stats = "quartiles", plotted directly.
plot(
  light_profile(
    fl,
    v = "fare",
    type = "residual",
    stats = "quartiles"
  )
)

Predict Parts

BreakDown

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# Breakdown of the prediction for the first row of titanic_imputed.
bd <- light_breakdown(
  fl,
  new_obs = titanic_imputed[1, ],
  n_max = 1000
)
plot(bd)

Predict Profile

ICE - One model

# Prediction wrapper handed to flashlight: extracts the second column of the
# `predictions` matrix returned by predict() for the ranger model.
custom_predict <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}
fl <- flashlight(
  model = ranger_model,
  data = titanic_imputed,
  y = "survived",
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict
)

# ICE curves for `fare`, called with a fixed seed and n_max = 200.
cp <- light_ice(
  fl,
  v = "fare",
  seed = 123,
  n_max = 200
)
plot(cp)

ICE - Two models

# Prediction wrappers: each extracts the second probability column from the
# respective model's predict() output.
custom_predict_ranger <- function(X.model, new_data) {
  forest_output <- predict(X.model, new_data)
  forest_output$predictions[, 2]
}

custom_predict_rpart <- function(X.model, new_data) {
  class_probs <- predict(X.model, new_data)
  class_probs[, 2]
}

fl_ranger <- flashlight(
  model = ranger_model,
  label = "Titanic Ranger",
  metrics = list(auc = AUC),
  predict_function = custom_predict_ranger
)

fl_rpart <- flashlight(
  model = tree_model,
  label = "Titanic Tree",
  metrics = list(auc = AUC),
  predict_function = custom_predict_rpart
)

# Data and response are supplied once, at the multiflashlight level.
fl <- multiflashlight(
  list(fl_ranger, fl_rpart),
  data = titanic_imputed,
  y = "survived"
)

# ICE curves for `fare`, called with a fixed seed and n_max = 200.
cp <- light_ice(
  fl,
  v = "fare",
  seed = 123,
  n_max = 200
)
plot(cp)

Session info

# Record the R version, platform and package versions used for this report.
utils::sessionInfo()
## R version 3.6.1 (2019-07-05)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
## [5] LC_TIME=Polish_Poland.1250    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] MetricsWeighted_0.5.0 flashlight_0.7.4     
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.4.6     pillar_1.4.4     compiler_3.6.1   ggpubr_0.2.5    
##  [5] tools_3.6.1      rpart_4.1-15     digest_0.6.25    lattice_0.20-40 
##  [9] evaluate_0.14    lifecycle_0.2.0  tibble_3.0.1     gtable_0.3.0    
## [13] pkgconfig_2.0.3  rlang_0.4.10     Matrix_1.2-18    yaml_2.2.1      
## [17] xfun_0.12        ranger_0.12.1    stringr_1.4.0    dplyr_1.0.4     
## [21] knitr_1.28       generics_0.0.2   vctrs_0.3.6      grid_3.6.1      
## [25] tidyselect_1.1.0 glue_1.4.1       R6_2.4.1         rmarkdown_2.1   
## [29] farver_2.0.3     tidyr_1.0.2      ggplot2_3.3.2    purrr_0.3.4     
## [33] magrittr_1.5     scales_1.1.1     ellipsis_0.3.1   htmltools_0.4.0 
## [37] rpart.plot_3.0.8 colorspace_1.4-1 ggsignif_0.6.0   labeling_0.3    
## [41] stringi_1.4.6    munsell_0.5.0    crayon_1.3.4