Hello,
I am curious whether anyone has had success implementing a survival model with pymc-bart, similar to what was done in the paper "Nonparametric survival analysis using Bayesian Additive Regression Trees (BART)" and its corresponding R package, BART.
If I understand the paper correctly, I believe I could replicate it in PyMC by simply creating the long-form discrete-time dataset and then modeling it with BART and a Bernoulli likelihood, like below:
import pymc as pm
import pymc_bart as pmb

# trainx2 (long-form covariates, including time), delta2 (long-form event indicator)
# and off (offset) are built beforehand
with pm.Model() as bart3:
    x_data = pm.MutableData("x", trainx2)
    f = pmb.BART("f", X=x_data, Y=delta2, m=100)
    z = pm.Deterministic("z", f + off)  # off is an offset number calculated from qnorm(mean(delta))
    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    y_pred = pm.Bernoulli("y_pred", p=mu, observed=delta2, shape=x_data.shape[0])
    smp3 = pm.sample(random_seed=2, draws=100)
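For anyone unfamiliar with the setup, here is a minimal sketch of how the long-form discrete-time data and the offset could be built, as I understand the construction in the paper. The function name `to_long_form` and the toy data are hypothetical and only for illustration; `NormalDist().inv_cdf` from the standard library stands in for R's `qnorm`.

```python
import numpy as np
from statistics import NormalDist


def to_long_form(times, events, X):
    """Expand one row per subject into one row per (subject, grid time at risk).

    Each subject contributes a row for every distinct observed time up to and
    including their own time; the response is 1 only on the row where the
    subject has an event, and 0 otherwise.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    X = np.asarray(X, dtype=float)
    grid = np.unique(times)  # sorted distinct event/censoring times
    rows, y = [], []
    for t_i, d_i, x_i in zip(times, events, X):
        for t_j in grid[grid <= t_i]:
            rows.append(np.concatenate(([t_j], x_i)))  # time is the first column
            y.append(int(d_i == 1 and t_j == t_i))
    return np.array(rows), np.array(y), grid


# Toy data (hypothetical): 4 subjects, 1 covariate
times = [2.0, 3.0, 3.0, 5.0]
events = [1, 0, 1, 1]  # 1 = event observed, 0 = censored
X = [[0.1], [0.4], [0.7], [0.9]]

trainx_long, delta_long, grid = to_long_form(times, events, X)

# Offset on the probit scale, analogous to qnorm(mean(delta)) in R
off = NormalDist().inv_cdf(float(delta_long.mean()))
```

Here `trainx_long` and `delta_long` would play the roles of `trainx2` and `delta2` in the model above.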
Then, as a simple check, I generated predictions on a test dataset made up of all the training data, with the time points for each row extended through the max time.
with bart3:
    pm.set_data({"x": testx2})
    pp3 = pm.sample_posterior_predictive(smp3, var_names=["y_pred", "f", "z", "mu"])
The posterior predictive variable "mu" can be used as the risk (the discrete-time hazard) at each time point, and from that the survival probability is generated as the cumulative product of (1 - mu) over the time points.
To continue the simple check, I calculated the average survival probability at each time point across the test dataset and compared it against the predicted averages from an RSF, the R BART package, and Kaplan-Meier (KM) estimates.
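A minimal sketch of that hazard-to-survival step, assuming the draws of "mu" have already been flattened to shape (n_draws, n_subjects * n_times) and that the test rows are ordered subject by subject with the time points ascending within each subject. The function name `survival_from_hazard` and the numbers are hypothetical.

```python
import numpy as np


def survival_from_hazard(mu_draws, n_subjects, n_times):
    """Turn per-interval event probabilities into survival curves.

    mu_draws: array of shape (n_draws, n_subjects * n_times), assumed to be
    ordered subject-major (all time points for subject 0, then subject 1, ...).
    Returns an array of shape (n_draws, n_subjects, n_times).
    """
    hazard = np.asarray(mu_draws, dtype=float).reshape(-1, n_subjects, n_times)
    # S(t_j) = prod_{l <= j} (1 - p_l)
    return np.cumprod(1.0 - hazard, axis=-1)


# One posterior draw, two subjects, three time points (made-up numbers)
mu_draws = np.array([[0.1, 0.2, 0.5, 0.0, 0.5, 0.5]])
surv = survival_from_hazard(mu_draws, n_subjects=2, n_times=3)
```

If the row ordering of the test set differs from this assumption, the reshape has to be adapted, otherwise the cumulative product mixes subjects.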
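The comparison could be sketched roughly as follows, assuming survival draws of shape (n_draws, n_subjects, n_times) on the same time grid as the Kaplan-Meier estimate. The `kaplan_meier` helper is a hand-rolled illustration rather than a library call, and all numbers are made up.

```python
import numpy as np


def kaplan_meier(times, events):
    """Plain Kaplan-Meier estimate evaluated at the distinct observed times."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    grid = np.unique(times)
    at_risk = np.array([(times >= t).sum() for t in grid])
    deaths = np.array([((times == t) & (events == 1)).sum() for t in grid])
    return grid, np.cumprod(1.0 - deaths / at_risk)


# Toy survival data (hypothetical)
times = [2.0, 3.0, 3.0, 5.0]
events = [1, 0, 1, 1]
grid, km = kaplan_meier(times, events)

# Model-based survival draws (hypothetical): 1 draw, 2 subjects, 3 time points
surv = np.array([[[0.9, 0.72, 0.36],
                  [1.0, 0.50, 0.25]]])

# Average over posterior draws and subjects to get one curve per time point
avg_surv = surv.mean(axis=(0, 1))

# Pointwise gap between the model average and the Kaplan-Meier curve
gap = avg_surv - km
```

Averaging the survival curves (rather than averaging the hazards first and then taking the cumulative product) is the choice made here; the two are not equivalent in general.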
Surprisingly, this worked relatively well in my first attempts.
However, my PyMC survival predictions seem a little off compared with the other models.
Does anyone have any ideas on whether I am missing anything in my model setup, or has anyone had success with a similar survival + BART model?
Thanks!