This report presents the capabilities of the survxai package.

This document is part of the paper “Landscape of R packages for eXplainable Machine Learning” by S. Maksymiuk, A. Gosiewska, and P. Biecek (https://arxiv.org/abs/2009.13248). It contains a real-life use case for the package. Because survxai works with survival models, this example uses the pbc survival data set instead of the titanic_imputed data set described in Section “Example gallery for XAI packages” of the article.

We did our best to show the entire range of the implemented explanations. Please note that the examples may be incomplete. If you think something is missing, feel free to open a pull request in the GitHub repository MI2DataLab/XAI-tools.

The list of use cases for all packages included in the article is here.

Load the pbc data set.

# Load the pbc data set shipped with randomForestSRC and keep complete cases
# only (any row with a missing value is dropped).
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), ]
# Treat sex and stage as categorical covariates.
pbc$sex <- as.factor(pbc$sex)
pbc$stage <- as.factor(pbc$stage)

pbc_smaller <- pbc[, c("days", "status", "treatment", "sex", "age", "bili",
                       "stage")]
# Convert follow-up time from days to years using 365.25 days per year
# (not 356, which would inflate every survival time by about 2.6%).
pbc_smaller$years <- pbc_smaller$days / 365.25
# `years` replaces `days` as the time variable; drop `days` by name rather
# than by position so this does not depend on column order.
pbc_smaller$days <- NULL
head(pbc_smaller)
##   status treatment sex   age bili stage     years
## 1      1         1   1 21464 14.5     4  1.095140
## 2      0         1   1 20617  1.1     3 12.320329
## 3      1         1   0 25594  1.4     4  2.770705
## 4      1         1   1 19994  1.8     4  5.270363
## 5      0         2   1 13918  3.4     3  4.117728
## 7      0         2   1 20284  1.0     3  5.015743
library(survxai)
library(rms)
library(randomForestSRC)

# Fix the RNG seed so that random steps later in the report (e.g. the
# random survival forest fitted with rfsrc) are reproducible.
set.seed(123)

Fit a Cox proportional hazards model and a random survival forest.

# Both models regress the survival outcome Surv(years, status) on every other
# column of pbc_smaller (treatment, sex, age, bili, stage).
# For cph(), surv = TRUE, x = TRUE and y = TRUE keep the survival estimates,
# design matrix and response in the fitted object (presumably needed by the
# survxai explainers below).
cph_model <- cph(
  Surv(years, status) ~ .,
  data = pbc_smaller,
  surv = TRUE,
  x = TRUE,
  y = TRUE
)
rf_model <- rfsrc(Surv(years, status) ~ ., data = pbc_smaller)

Model profile

Partial dependence profile (PDP)

# Explainer for the Cox model: the predictors (status and years removed by
# name) plus the observed survival outcome.
surve_cph <- explain(
  model = cph_model,
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)

# Profile of the Cox model's predicted survival across the levels of `sex`.
vr_cph_sex <- variable_response(surve_cph, "sex")
plot(vr_cph_sex)

Comparison of models

# One explainer per model, built on the same predictors (status and years
# removed by name) and the same observed outcome so the profiles are
# comparable. The forest is given the label "random forest".
surve_cph <- explain(
  model = cph_model,
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)
surve_rf <- explain(
  model = rf_model,
  label = "random forest",
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)

# Profiles of predicted survival across the levels of `sex`, one per model,
# shown together in a single plot call.
vr_cph_sex <- variable_response(surve_cph, "sex")
vr_rf_sex <- variable_response(surve_rf, "sex")
plot(vr_cph_sex, vr_rf_sex)

Model Diagnostics

Prediction Error Curve for Brier Score

# Explainer for the Cox model: the predictors (status and years removed by
# name) plus the observed survival outcome.
surve_cph <- explain(
  model = cph_model,
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)

# Prediction error curve (Brier score) of the Cox model's survival
# predictions against the outcome stored in the explainer.
mp_cph <- model_performance(surve_cph)
plot(mp_cph)

Predict parts

Break Down

# Explainer for the Cox model: the predictors (status and years removed by
# name) plus the observed survival outcome.
surve_cph <- explain(
  model = cph_model,
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)

# Break the Cox model's prediction for the first patient down into
# per-variable contributions.
# NOTE(review): the full first row is passed, including status and years,
# which are not in the explainer's data; confirm survxai ignores them.
broken_prediction_cph <- prediction_breakdown(surve_cph, pbc_smaller[1, ])
plot(broken_prediction_cph)

Predict profile

Ceteris paribus

# Explainer for the Cox model: the predictors (status and years removed by
# name) plus the observed survival outcome.
surve_cph <- explain(
  model = cph_model,
  data = pbc_smaller[, setdiff(names(pbc_smaller), c("status", "years"))],
  y = Surv(pbc_smaller$years, pbc_smaller$status)
)

# Ceteris paribus profiles for the first patient: how the Cox model's
# prediction changes when one variable varies and the others stay fixed.
# NOTE(review): the observation includes status and years; confirm that
# survxai only profiles the explainer's predictor columns.
cp_cph <- ceteris_paribus(surve_cph, pbc_smaller[1, ])
plot(cp_cph)

Session info

# Record the R version, platform, locale and package versions used to render
# this report so that the results can be reproduced.
sessionInfo()
## R version 3.6.1 (2019-07-05)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
## [5] LC_TIME=Polish_Poland.1250    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] randomForestSRC_2.9.3 rms_5.1-4             SparseM_1.78         
##  [4] Hmisc_4.3-1           ggplot2_3.3.2         Formula_1.2-3        
##  [7] survival_3.1-11       lattice_0.20-40       survxai_0.2.2        
## [10] prodlim_2019.11.13   
## 
## loaded via a namespace (and not attached):
##  [1] tidyr_1.0.2         splines_3.6.1       foreach_1.4.8      
##  [4] latticeExtra_0.6-29 yaml_2.2.1          pec_2019.11.03     
##  [7] timereg_1.9.6       numDeriv_2016.8-1.1 pillar_1.4.4       
## [10] backports_1.1.8     quantreg_5.54       glue_1.4.1         
## [13] digest_0.6.25       RColorBrewer_1.1-2  ggsignif_0.6.0     
## [16] checkmate_2.0.0     colorspace_1.4-1    sandwich_2.5-1     
## [19] htmltools_0.4.0     Matrix_1.2-18       pkgconfig_2.0.3    
## [22] breakDown_0.1.6     broom_0.5.6         mvtnorm_1.1-0      
## [25] purrr_0.3.4         xtable_1.8-4        scales_1.1.1       
## [28] km.ci_0.5-2         jpeg_0.1-8.1        lava_1.6.7         
## [31] KMsurv_0.1-5        MatrixModels_0.4-1  tibble_3.0.1       
## [34] htmlTable_1.13.3    farver_2.0.3        generics_0.0.2     
## [37] ellipsis_0.3.1      ggpubr_0.2.5        TH.data_1.0-10     
## [40] withr_2.2.0         nnet_7.3-12         magrittr_1.5       
## [43] crayon_1.3.4        polspline_1.1.17    evaluate_0.14      
## [46] MASS_7.3-51.6       nlme_3.1-140        foreign_0.8-76     
## [49] tools_3.6.1         data.table_1.12.8   lifecycle_0.2.0    
## [52] multcomp_1.4-12     stringr_1.4.0       munsell_0.5.0      
## [55] cluster_2.1.0       compiler_3.6.1      survminer_0.4.6    
## [58] rlang_0.4.10        grid_3.6.1          iterators_1.0.12   
## [61] rstudioapi_0.11     htmlwidgets_1.5.1   labeling_0.3       
## [64] base64enc_0.1-3     rmarkdown_2.1       gtable_0.3.0       
## [67] codetools_0.2-16    R6_2.4.1            gridExtra_2.3      
## [70] zoo_1.8-7           knitr_1.28          dplyr_1.0.4        
## [73] survMisc_0.5.5      stringi_1.4.6       parallel_3.6.1     
## [76] Rcpp_1.0.4.6        vctrs_0.3.6         rpart_4.1-15       
## [79] acepack_1.4.1       png_0.1-7           tidyselect_1.1.0   
## [82] xfun_0.12