Very high R-hat when fitting a PyMC-BART model

Hi,

I’ve been trying to use PyMC-BART for the first time.

I have been trying the BART model out on some time-to-event data. I have transformed the data into 50,000 person-interval slices of 1 year duration. Each person's intervals are censored after their incident event.
Each slice provides information on a given person at a given interval since study start. This information is: the number of intervals since study start; the age, sex, and various other continuous measurements updated to the start of the slice; the fraction of time spent in the slice (1 if no outcome event occurred, otherwise the fraction of the slice elapsed before the event); and a binary outcome indicating whether an incident event occurred within the slice. I have approximately 20 features/covariates, and I am trying to fit the following model, with a Poisson likelihood, to the incident event count (0 or 1 in this case) in the person-interval data.

with pm.Model() as model_t1dcvd:
    loghaz = pmb.BART("loghaz", X=x_data, Y=log_y_data, m=100)
    mean = pm.Deterministic("mean", np.exp(loghaz) + offset)
    y_pred = pm.Poisson("y_pred", mu=mean, observed=y_data)
    idata_t1dcvd = pm.sample(1000, tune=1000, pgbart={'num_particles':30, 'batch':(50, 50)}, cores=4)

Here log_y_data is 0 for slices where an event occurred (log(1)) and 0.1 for slices where it did not.
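
For readers unfamiliar with the person-interval layout described above, here is a minimal sketch of how a single follow-up record could be expanded into slices. The function name, the tuple layout, and the handling of censoring partway through a slice (treated as fractional exposure with no event) are assumptions for illustration; they are not taken from the original script.

```python
import math


def person_intervals(duration, event, width=1.0):
    """Expand one follow-up record into interval slices.

    Returns a list of (interval_index, exposure_fraction, event_indicator).
    Hypothetical helper: names and layout are illustrative only.
    """
    slices = []
    n_full = int(math.floor(duration / width))
    # Fully observed slices before follow-up ends: full exposure, no event.
    for k in range(n_full):
        slices.append((k, 1.0, 0))
    remainder = duration - n_full * width
    if remainder > 0:
        # Final partial slice: exposure is the fraction of the slice at risk.
        slices.append((n_full, remainder / width, int(event)))
    elif event and slices:
        # Event exactly on a slice boundary: attribute it to the last full slice.
        k, frac, _ = slices[-1]
        slices[-1] = (k, frac, 1)
    return slices


# Event after 2.5 years of follow-up, with 1-year slices.
print(person_intervals(2.5, True))
```

Stacking these tuples over all people gives the slice-level rows, with the exposure fraction playing the role of the offset in the model.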

I get very high r-hat values (median > 1.7) when I fit this model.
I’ve tried increasing num_particles and batch, but R-hat has not come down substantially, while running times have increased significantly.
Are there other strategies available to improve mixing? How sensitive is pmb.BART to the Y argument with respect to mixing?

Thanks

Welcome!

Perhaps @aloctavodia might be able to weigh in on this?

Did you try Y=np.log(y_data + 1)?

I am not sure I understand your description of y_data. Is y_data a vector of 0s and 1s? If so, why not use a binomial/Bernoulli likelihood?

Regarding your question about how sensitive mixing is to the Y argument, I cannot provide a definitive answer yet, but it is something I am exploring. The main reason for the Y argument is to give some guidance on how to initialize the BART variable. In my experiments the results (in terms of the fit) are quite robust to the values passed to that argument. In many cases, passing the observed data or some transformation of it seems to be the right choice. Nevertheless, I am exploring an alternative API that gives better results and more user flexibility while remaining simple. If you can share your data with me in private, that could help me achieve that goal.


Yes, the y_data is a vector of 0s and 1s.

I am trying to fit a piecewise-constant proportional hazards model using the trick of fitting a Poisson model to person-interval data. (While searching for a reference to this, I realised there is a PyMC notebook on fitting such models without BART: "Bayesian Survival Analysis" in the PyMC example gallery.) I will look at that notebook to check that there is not a problem with how I’ve specified the rest of the model that is unrelated to the BART component.

The specific data I am unable to share, but I will work on switching to another dataset that I can share with you.

OK, I will check that example. I would appreciate it if you could share a dataset.

Thanks,

I created a simulated data set. I’ve put it and the code I ran in a gist here: BART example · GitHub

The R-hat values are all approximately 2 when I run the try_bart_sim.py script.

After looking at the example notebook, I realised I had made an error with the offset term (I hadn’t noticed that I was accidentally adding the offset rather than multiplying by it).
My R-hat values are not nearly as large once that error is corrected (apologies for not spotting this before posting). I’ve updated my gist with a corrected script.
Currently ~90% of the R-hat values are <= 1.01.
However, I do still get these warning messages:
“The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See [1903.08008] Rank-normalization, folding, and localization: An improved $\widehat{R}$ for assessing convergence of MCMC for details”
“The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See [1903.08008] Rank-normalization, folding, and localization: An improved $\widehat{R}$ for assessing convergence of MCMC for details”
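
To make the offset error concrete for anyone reading later, here is a small sketch comparing the two forms of the Poisson mean. It assumes the offset holds the exposure fraction on the natural scale (1 for a full slice), which is how I read the correction; the function names are hypothetical and this is not the gist code.

```python
import math


def poisson_mean_multiplicative(loghaz, exposure):
    # Intended form: hazard rate times time at risk.
    return math.exp(loghaz) * exposure


def poisson_mean_additive(loghaz, exposure):
    # Mistaken form: the exposure is added to the rate.
    return math.exp(loghaz) + exposure


loghaz = -3.0  # a low log-hazard, i.e. a rare event
for exposure in (1.0, 0.5):
    good = poisson_mean_multiplicative(loghaz, exposure)
    bad = poisson_mean_additive(loghaz, exposure)
    print(f"exposure={exposure}: multiplicative={good:.4f}, additive={bad:.4f}")

# The multiplicative form is the usual log-link offset:
# log(mu) = loghaz + log(exposure).
log_link = math.exp(loghaz + math.log(0.5))
print(math.isclose(poisson_mean_multiplicative(loghaz, 0.5), log_link))
```

Under that assumption the additive form can never give a mean below the exposure itself, so a full slice always has a mean above 1 even though almost all observed counts are 0. That may be part of why the sampler struggled, although I have not verified this.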

Good to hear you found that issue.

There are many reports in the literature (and in private conversations) of good predictions from BART models (in general, not just PyMC-BART) even in the presence of convergence issues. In my experience, it is enough to use pmb.plot_convergence and check that “most” of the ECDF for R-hat is below the threshold and most of the ECDF for ESS is above the threshold.
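
As a rough illustration of the "most R-hat values below the threshold" check, here is a minimal sketch in plain Python. It uses the classic split R-hat rather than the rank-normalized version behind the warning quoted earlier, and the function names and simulated chains are invented for illustration. On real output, pmb.plot_convergence is the tool to use.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split R-hat for one parameter.

    `chains` is a list of equal-length lists of draws. This is the
    non-rank-normalized statistic, used here only as a simple stand-in.
    """
    half = len(chains[0]) // 2
    halves = []
    for chain in chains:
        # Split every chain in two so within-chain drift is detected as well.
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean([statistics.variance(h) for h in halves])
    between_over_n = statistics.variance(means)  # B / n
    var_plus = (half - 1) / half * within + between_over_n
    return (var_plus / within) ** 0.5


def fraction_below(rhats, threshold=1.01):
    """Share of R-hat values at or below the threshold."""
    return sum(r <= threshold for r in rhats) / len(rhats)


rng = random.Random(0)
# Four chains sampling the same distribution (well mixed).
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Four chains stuck around two different values (poorly mixed).
stuck = [[rng.gauss(shift, 1.0) for _ in range(1000)] for shift in (0.0, 0.0, 5.0, 5.0)]

rhats = [split_rhat(mixed), split_rhat(stuck)]
print(rhats, fraction_below(rhats))
```

With a BART variable there is one R-hat value per observation, so the fraction returned by fraction_below corresponds to where the R-hat ECDF crosses the threshold.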