
The xgboost package supports survival modeling via the parameters objective = "survival:cox" and eval_metric = "cox-nloglik".

The predict method for the resulting model outputs only risk scores (the same as predict(..., type = "risk") for a survival::coxph model in R).

How do I use xgboost to predict entire survival curves?

Iyar Lin

2 Answers


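The proportional hazards model assumes hazard rates of the form $h(t|X) = h_0(t) \cdot risk(X)$, where usually $risk(X) = \exp(X\beta)$. The xgboost predict method returns only $risk(X)$. What we can do is use the survival::basehaz function (on a coxph model fitted to the same data) to estimate the cumulative baseline hazard $H_0(t) = \int_0^t h_0(u)\,du$. The survival curve of each observation then follows as $S(t|X) = \exp(-H_0(t) \cdot risk(X))$.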
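As a minimal sketch of that last step (in Python rather than R, and assuming the cumulative baseline hazard is available as a sorted step function, such as the time and hazard columns returned by survival::basehaz), the survival curve of one observation can be evaluated from its risk score as follows. The function name survival_curve is hypothetical.

```python
import math
from bisect import bisect_right


def survival_curve(base_times, base_cumhaz, risk, eval_times):
    """Evaluate S(t | X) = exp(-H0(t) * risk(X)) at each time in eval_times.

    base_times / base_cumhaz: sorted step function for the cumulative
    baseline hazard H0(t); H0 is taken to be 0 before the first time point.
    risk: the exp(X beta)-style risk score of one observation
    (for xgboost with objective = "survival:cox", its default prediction).
    """
    surv = []
    for t in eval_times:
        # index just past the last baseline time <= t (right-continuous step)
        k = bisect_right(base_times, t)
        h0 = base_cumhaz[k - 1] if k > 0 else 0.0
        surv.append(math.exp(-h0 * risk))
    return surv
```

For example, survival_curve([1, 2, 3], [0.1, 0.3, 0.6], 2.0, [0.5, 1, 2.5, 3]) returns [1.0, exp(-0.2), exp(-0.6), exp(-1.2)]: a larger risk score makes the whole curve drop faster, but the shape over time comes entirely from $H_0(t)$.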

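The problem is that this baseline hazard is not "calibrated" to xgboost: it is estimated from the coxph fit, whose risk scores are on a different scale from the ones xgboost produces. To fix this, we can find a constant $C$ that minimizes the integrated Brier score (ibrier) between the observed death/censoring times in the sample and the survival curves implied by $h_0(t) \cdot risk(X) \cdot C$.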
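A minimal sketch of the calibration step in Python, under assumptions that are mine and not necessarily what the package does: the Brier score is the inverse-probability-of-censoring-weighted (IPCW) version with the censoring distribution estimated by Kaplan-Meier, it is integrated with the trapezoidal rule over a user-chosen time grid, and $C$ is picked from a finite list of candidates. All function names are hypothetical. The weights are also how censoring times enter the score: subjects censored before $t$ contribute nothing directly, and the others are up-weighted to compensate.

```python
import math
from bisect import bisect_left, bisect_right


def step_at(xs, ys, t, before):
    """Right-continuous step function: value at the last xs[k] <= t."""
    k = bisect_right(xs, t)
    return ys[k - 1] if k > 0 else before


def step_left(xs, ys, t, before):
    """Left limit of the step function at t: value at the last xs[k] < t."""
    k = bisect_left(xs, t)
    return ys[k - 1] if k > 0 else before


def censoring_km(times, events):
    """Kaplan-Meier estimate of the censoring survival G(t) = P(C > t)."""
    xs, ys, g = [], [], 1.0
    for s in sorted(set(times)):
        at_risk = sum(1 for t in times if t >= s)
        censored = sum(1 for t, e in zip(times, events) if t == s and e == 0)
        g *= 1.0 - censored / at_risk
        xs.append(s)
        ys.append(g)
    return xs, ys


def integrated_brier(times, events, risks, base_times, base_cumhaz, grid, c=1.0):
    """IPCW integrated Brier score of S(t|X) = exp(-c * H0(t) * risk(X))."""
    gx, gy = censoring_km(times, events)
    scores = []
    for t in grid:
        h0 = step_at(base_times, base_cumhaz, t, 0.0)
        g_t = step_at(gx, gy, t, 1.0)
        total = 0.0
        for ti, ei, r in zip(times, events, risks):
            s = math.exp(-c * h0 * r)
            if ti <= t and ei == 1:
                # died by t: observed status 0, weighted by 1 / G(T_i-)
                g_i = step_left(gx, gy, ti, 1.0)
                if g_i > 0:
                    total += s ** 2 / g_i
            elif ti > t and g_t > 0:
                # still alive at t: observed status 1, weighted by 1 / G(t)
                total += (1.0 - s) ** 2 / g_t
            # censored at or before t: no direct contribution
        scores.append(total / len(times))
    # trapezoidal rule, normalized by the width of the time window
    area = sum((grid[k + 1] - grid[k]) * (scores[k] + scores[k + 1]) / 2
               for k in range(len(grid) - 1))
    return area / (grid[-1] - grid[0])


def calibrate_c(times, events, risks, base_times, base_cumhaz, grid, candidates):
    """Return the candidate C with the lowest integrated Brier score."""
    return min(candidates, key=lambda c: integrated_brier(
        times, events, risks, base_times, base_cumhaz, grid, c))
```

Here times/events are the observed data, risks are xgboost's risk-scale predictions, and base_times/base_cumhaz are the baseline from survival::basehaz; the selected $C$ rescales that cumulative baseline hazard. The nested loops are quadratic and only meant to show the logic; a 1-D optimizer could replace the candidate grid.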

I've implemented this approach in survXgboost, a tiny R package I wrote.

Iyar Lin
  • How can the C-index be calculated in this setting? Can your package do this? Thanks! – Tommy Mar 31 '22 at 14:07
  • My package does not include that functionality. I would recommend using either the `pec` or the `riskRegression` package. – Iyar Lin Apr 01 '22 at 17:34
  • How do censoring times contribute to the ibrier score? – 42- Jun 27 '22 at 22:46
  • I'm not sure I understand the question. The integrated Brier score is a function of the estimated and observed survival functions, and censoring times also play a role in it. – Iyar Lin Jun 29 '22 at 10:13
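On the C-index question in the comments: the R packages suggested there are the complete route, but as a minimal Python sketch, Harrell's concordance index can be computed directly from the xgboost risk scores. This version only counts pairs with strictly different times where the earlier time is an event, and scores tied risk scores as 1/2; the function name is hypothetical. Because it depends only on the ranking of the risk scores, it is unaffected by the calibration constant $C$.

```python
def harrell_c_index(times, events, risks):
    """Harrell's C for risk scores (a higher risk should mean an earlier event).

    A pair (i, j) is comparable when times[i] < times[j] and subject i had
    an event; it is concordant when risks[i] > risks[j], and tied risks
    count as half a concordant pair.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

A value of 1 means the risk scores order every comparable pair correctly, and 0.5 is what a constant score would get.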

The approach of using survival::basehaz() with a coxph model and estimating a constant C, as implemented in survXgboost, should be used with caution. With binary predictors, the coxph coefficients can explode, which leads to a badly overestimated baseline hazard. The constant C does little to correct this, so the performance of xgboost looks much worse than it really is.

The gbm package has a function, gbm::basehaz.gbm, that skips the model: instead of a fitted model object, it takes the observed times, the event indicators and the predictions on the log-hazard scale, which avoids the compatibility problem you have with survival::basehaz(). It is more reliable, and the resulting (cumulative) baseline hazard behaves as expected.
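To show what "skipping the model" means, here is a minimal Python sketch of the Breslow estimator, which is the estimator gbm documents basehaz.gbm as computing. It uses only observed times, event indicators and predictions on the log-hazard scale (for xgboost with survival:cox, the log of the default risk-scale predictions). The function name is hypothetical, and tied event times are handled by the plain Breslow rule.

```python
import math


def breslow_cumhaz(times, events, log_risk):
    """Breslow estimate of the cumulative baseline hazard H0(t).

    times: observed times; events: 1 = death, 0 = censored;
    log_risk: model predictions on the log-hazard scale.
    Returns (event_times, cumulative_hazard) describing a step function.
    """
    risk = [math.exp(f) for f in log_risk]
    event_times = sorted({t for t, e in zip(times, events) if e == 1})
    out_times, out_h, h = [], [], 0.0
    for s in event_times:
        deaths = sum(1 for t, e in zip(times, events) if t == s and e == 1)
        # sum of risk scores over everyone still at risk at time s
        denom = sum(r for t, r in zip(times, risk) if t >= s)
        h += deaths / denom
        out_times.append(s)
        out_h.append(h)
    return out_times, out_h
```

With all predictions equal to 0 it reduces to the Nelson-Aalen estimator: for times [1, 2, 3] with three deaths it gives 1/3, 5/6 and 11/6. Because the risk set is weighted by the model's own predictions, the resulting baseline is on the same scale as those predictions, so no extra constant is needed.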

maribuon