Custom_predict_for_survival_models.Rmd
This vignette contains example predict functions for survival models. Predict functions for some model classes are already implemented, so for those models there is no need to specify a predict function.
# Load the pbc dataset from the randomForestSRC package.
data(pbc, package = "randomForestSRC")
# Keep only the rows that have no missing values.
pbc <- pbc[complete.cases(pbc), , drop = FALSE]
# Store sex and stage as factors instead of their original type.
pbc[["sex"]] <- as.factor(pbc[["sex"]])
pbc[["stage"]] <- as.factor(pbc[["stage"]])
Predict functions are currently implemented for the model classes listed below. Objects of these classes don’t need a user-specified predict function.
aalen
riskRegression
cox.aalen
cph
coxph
matrix
selectCox
pecCforest
prodlim
psm
survfit
pecRpart
pecCtree
# Fix the random seed before any model is fitted.
set.seed(1024)
library(rms)
library(survxai)

# Cox model of survival time in years (days / 365) on all remaining columns.
# surv, x and y are switched on so the fit keeps the extra components.
cph_model <- cph(
  Surv(days / 365, status) ~ .,
  data = pbc,
  surv = TRUE,
  x = TRUE,
  y = TRUE
)

# No predict_function is passed here: cph is one of the supported classes.
# Columns 1 and 2 are dropped from the explainer data
# (presumably days and status - confirm against the pbc column order).
surve_cph <- explain(
  model = cph_model,
  data = pbc[, -c(1, 2)],
  y = Surv(pbc$days / 365, pbc$status)
)
The predict function for class rfsrc
is not implemented. Therefore, a custom predict function must be provided.
library(prodlim)
library(randomForestSRC)
# Survival-probability predictor for random survival forests (class rfsrc).
#
# Arguments:
#   object  - fitted forest; the code reads its `xvar` and `time.interest`
#             components and calls predict() on it.
#   newdata - data frame of observations to predict for.
#   times   - numeric vector of time points at which survival is evaluated.
#   ...     - accepted for signature compatibility; not used.
#
# Returns a matrix with one column per element of `times`; rows follow the
# rows of the `survival` component returned by predict() (presumably one per
# row of `newdata` - confirm against the randomForestSRC documentation).
predict_rf <- function(object, newdata, times, ...) {
  # vapply() always yields a logical vector, even for zero-column input,
  # whereas sapply() would return an empty list there.
  is_int <- vapply(newdata, is.integer, logical(1))
  # Restrict to columns the forest actually stores, so an extra integer
  # column in newdata cannot trigger an "undefined columns" error.
  int_cols <- intersect(names(newdata)[is_int], names(object$xvar))
  if (length(int_cols) > 0) {
    # Coerce the stored covariates to integer wherever newdata holds integers,
    # so both sides use the same storage type at prediction time.
    object$xvar[int_cols] <- lapply(object$xvar[int_cols], as.integer)
  }
  surv_at_events <- predict(
    object,
    newdata = newdata,
    importance = "none"
  )$survival
  # For every requested time, index of the last event time not after it.
  pos <- prodlim::sindex(
    jump.times = object$time.interest,
    eval.times = times
  )
  # The prepended column of ones is what index 0 (pos + 1 == 1) selects,
  # presumably for times before the first event time.
  cbind(1, surv_at_events)[, pos + 1, drop = FALSE]
}
# Survival time expressed in years, stored as a new column.
pbc[["year"]] <- pbc[["days"]] / 365

# Random survival forest on every column except the first one
# (presumably days, which year now replaces - confirm).
rf_model <- rfsrc(Surv(year, status) ~ ., data = pbc[, -1])

# rfsrc has no built-in predict function, so predict_rf is supplied.
# Columns 1, 2 and 20 are dropped from the explainer data
# (presumably days, status and year - confirm against the column order).
surve_rf <- explain(
  model = rf_model,
  data = pbc[, -c(1, 2, 20)],
  y = Surv(pbc$year, pbc$status),
  predict_function = predict_rf
)
The predict function for class survreg
is not implemented. Therefore, a custom predict function must be provided.
library(survival)
# Survival-probability predictor for parametric survival regression models
# (class survreg).
#
# Arguments:
#   model   - fitted model; element 2 of its `call` must be the model formula
#             with the response on the left-hand side.
#   newdata - data frame of observations; columns named in the response are
#             removed before the design matrix is built.
#   times   - numeric vector of evaluation times; they are sorted ascending
#             first, so the output columns follow the sorted order.
#
# Returns a matrix with nrow(newdata) rows and length(times) columns.
#
# NOTE(review): cfc.survreg.survprob() is not defined in this file; it looks
# like the CFC package function of that name - confirm that CFC is attached
# before this function is called.
predict_reg <- function(model, newdata, times) {
  times <- sort(times)
  # Names of the variables used on the left-hand side of the model formula.
  response_vars <- all.vars(model$call[[2]][[2]])
  response_cols <- which(colnames(newdata) %in% response_vars)
  if (length(response_cols) > 0) {
    # drop = FALSE keeps a data frame even when only one covariate remains;
    # without it the result collapses to a vector and model.matrix() fails.
    newdata <- newdata[, -response_cols, drop = FALSE]
  }
  # Replace the stored design matrix with that of the new observations, so
  # that row i of model$x belongs to observation i of newdata.
  model$x <- model.matrix(~., newdata)
  res <- matrix(ncol = length(times), nrow = nrow(newdata))
  # seq_len() skips the loop for zero rows; 1:0 would iterate over c(1, 0).
  for (i in seq_len(nrow(newdata))) {
    res[i, ] <- cfc.survreg.survprob(t = times, args = model, n = i)
  }
  res
}
# Parametric survival regression on every column except the first one.
# x = TRUE keeps the design matrix in the fit; predict_reg overwrites
# model$x with the design matrix of the data it is asked to predict for.
reg_model <- survreg(Surv(year, status) ~ ., data = pbc[, -1], x = TRUE)

# Explain the survreg fit itself: predict_reg reads model$call and model$x,
# so it must receive reg_model rather than the random forest.
surve_reg <- explain(
  model = reg_model,
  data = pbc[, -c(1, 2, 20)],
  y = Surv(pbc$year, pbc$status),
  predict_function = predict_reg
)