BART survival: has anyone attempted it?

I am curious whether anyone has had success implementing a survival model using pymc-bart, similar to the approach in the paper "Nonparametric survival analysis using Bayesian Additive Regression Trees (BART)" (available on PMC) and its companion R package, BART.

If I understand the paper correctly, I should be able to replicate it in PyMC by creating the long-form, discrete-time dataset and then modeling it with BART and a Bernoulli likelihood, like so:

```python
import pymc as pm
import pymc_bart as pmb

with pm.Model() as bart3:
    # long-form (person-period) covariates: time point plus subject covariates
    x_data = pm.MutableData("x", trainx2)

    f = pmb.BART("f", X=x_data, Y=delta2, m=100)
    z = pm.Deterministic("z", f + off)  # off is an offset calculated as qnorm(mean(delta))

    mu = pm.Deterministic("mu", pm.math.invprobit(z))  # probit link: discrete-time hazard
    y_pred = pm.Bernoulli("y_pred", p=mu, observed=delta2, shape=x_data.shape[0])
    smp3 = pm.sample(random_seed=2, draws=100)
```
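For reference, the long-form (person-period) construction and the probit offset can be sketched in plain NumPy. This is a minimal sketch under assumptions: the time grid is taken to be the sorted unique observed times (events and censorings), time is placed as the first covariate column, and `to_long_form` and `probit_offset` are hypothetical helpers, not part of pymc-bart (in the R package, `surv.pre.bart` plays roughly this role).

```python
import numpy as np
from statistics import NormalDist


def to_long_form(times, delta, X):
    """Expand one-row-per-subject survival data into person-period rows.

    Assumption: the grid is the sorted unique observed times. Each subject gets
    one row per grid time up to and including its own observed time; the
    response is 1 only on the subject's last row, and only if it was an event.
    """
    times = np.asarray(times, dtype=float)
    delta = np.asarray(delta, dtype=int)
    X = np.asarray(X, dtype=float).reshape(len(times), -1)
    grid = np.unique(times)  # sorted unique observed times

    rows, y = [], []
    for t_i, d_i, x_i in zip(times, delta, X):
        for t_j in grid[grid <= t_i]:
            rows.append(np.concatenate(([t_j], x_i)))  # time is the first column
            y.append(int(d_i == 1 and t_j == t_i))
    return np.array(rows), np.array(y), grid


def probit_offset(y_long):
    """Python equivalent of R's qnorm(mean(y)) on the long-form response."""
    return NormalDist().inv_cdf(float(np.mean(y_long)))
```

Something like `trainx2, delta2, grid = to_long_form(times, delta, X)` followed by `off = probit_offset(delta2)` would then supply the inputs used in the model above (`times`, `delta`, `X` being the per-subject data).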

Then, as a simple check, I generated predictions on a test dataset consisting of all the training subjects, with their time points extended through the maximum time.

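```python
with bart3:
    # swap in the long-form test matrix (testx2 is a hypothetical name for it)
    pm.set_data({"x": testx2})
    pp3 = pm.sample_posterior_predictive(smp3, var_names=["y_pred", "f", "z", "mu"])
```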
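The test matrix (every subject crossed with every time point through the maximum time) can be sketched as follows. It assumes the same column layout as the training matrix (time first, then covariates) and a subject-major row order, so predictions reshape cleanly to (subjects, times); `make_test_grid` is a hypothetical helper.

```python
import numpy as np


def make_test_grid(X, grid):
    """Cross every subject's covariates with every time on the grid.

    Rows are subject-major: subject 0 at all grid times, then subject 1, etc.
    Column order (time first, then covariates) must match the training matrix.
    """
    X = np.asarray(X, dtype=float)
    X = X.reshape(X.shape[0], -1)
    n, k = X.shape[0], len(grid)
    t_col = np.tile(grid, n)          # grid, grid, ... once per subject
    x_cols = np.repeat(X, k, axis=0)  # each subject's row repeated k times
    return np.column_stack([t_col, x_cols])
```

The subject-major order is the design choice that matters here: it lets a flat vector of n_subjects * n_times predictions be reshaped back into one curve per subject.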

The posterior predictive variable `mu` can be used as the discrete-time hazard (risk) at each time point, and the survival probability is generated from it as the cumulative product of (1 - mu) over time.
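In discrete time, S(t_j | x) = prod_{l <= j} (1 - p_l(x)), where p_l(x) is the hazard at grid time t_l. A minimal NumPy sketch, assuming the subject-major test layout (each subject's time points contiguous) and that `mu` holds the flat predictions on its last axis (for example `pp3.posterior_predictive["mu"].values`, shaped chain x draw x rows); `survival_from_hazard` is a hypothetical helper.

```python
import numpy as np


def survival_from_hazard(mu, n_subjects, n_times):
    """Convert discrete-time hazards into survival curves.

    mu: array whose last axis holds subject-major predictions of length
        n_subjects * n_times; any leading axes (chains, draws) are kept.
    Returns S with shape (..., n_subjects, n_times) where
    S[..., i, j] = prod_{l <= j} (1 - mu[..., i, l]).
    """
    mu = np.asarray(mu, dtype=float)
    h = mu.reshape(mu.shape[:-1] + (n_subjects, n_times))
    return np.cumprod(1.0 - h, axis=-1)
```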

To continue the check, I calculated the average survival probability at each time point across the test dataset and compared it against the corresponding averages from a random survival forest (RSF), the R BART package, and Kaplan-Meier (KPM) estimates.
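A sketch of that comparison in NumPy, assuming survival curves shaped (draws, subjects, times) (chains flattened into draws) and a time grid that contains every observed event time; `average_survival` and `kaplan_meier` are hypothetical helpers, and the 95% interval is just one reasonable summary choice.

```python
import numpy as np


def average_survival(S):
    """Average survival over subjects, then summarise across posterior draws.

    S: array of shape (n_draws, n_subjects, n_times).
    Returns the posterior mean and a 95% interval of the subject-averaged curve.
    """
    avg = S.mean(axis=-2)  # (n_draws, n_times)
    lo, hi = np.percentile(avg, [2.5, 97.5], axis=0)
    return avg.mean(axis=0), lo, hi


def kaplan_meier(times, delta, grid):
    """Kaplan-Meier survival estimate evaluated at each grid time.

    Assumes the grid includes every observed event time, so each step of the
    product-limit estimator is applied at the correct time.
    """
    times = np.asarray(times, dtype=float)
    delta = np.asarray(delta, dtype=int)
    surv, s = [], 1.0
    for t in grid:
        at_risk = np.sum(times >= t)
        events = np.sum((times == t) & (delta == 1))
        if at_risk > 0:
            s *= 1.0 - events / at_risk
        surv.append(s)
    return np.array(surv)
```

Plotting the posterior mean and interval against the Kaplan-Meier curve on the same grid makes it easier to see whether a gap is systematic or within posterior uncertainty.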

Surprisingly, this worked relatively well in my first attempts.

However, my PyMC survival predictions seem a little off compared with the other models.
Does anyone have ideas about whether I am missing anything in my model setup, or has anyone had success with a similar survival + BART model?