Serializing BART models for out-of-sample data

Good afternoon!

After training a BART model, I have had some difficulty saving it and then loading it in a separate notebook.

What I have tried:

  1. Wrote a class that fits a BART model, computes diagnostics, saves the idata, and predicts on test data, all tracked in MLflow (within the same instance of the class). I am not using mlflow.pyfunc at this time, mostly because I am still deciding what I’d like MLflow to track.
  2. Saved the model instance and idata using cloudpickle.
  3. Loaded the idata and model in a separate notebook.
  4. Created an out-of-sample dataset and, using pm.set_data, sampled from the model’s posterior predictive distribution.
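Steps 2 and 3 amount to a serialize/deserialize round trip. Below is a minimal stdlib sketch of that round trip, using pickle as a stand-in for cloudpickle (cloudpickle.dump takes the same arguments). The dictionary contents and the helper names save_artifact/load_artifact are hypothetical. The sketch assumes you store plain data (the settings needed to rebuild the model plus posterior draws) rather than a live model object. That is one possible design, not a confirmed fix for the load failures.

```python
import pickle
import tempfile
from pathlib import Path


def save_artifact(obj, path):
    # Step 2: write the object to disk (cloudpickle.dump takes the same arguments).
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_artifact(path):
    # Step 3: read it back, e.g. in a separate notebook.
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical contents: plain data only, no live model object.
fitted = {
    "config": {"m": 50, "columns": ["x0", "x1"]},   # settings to rebuild the model
    "posterior_draws": [[0.1, 0.2], [0.15, 0.25]],  # placeholder numbers
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "bart_artifact.pkl"
    save_artifact(fitted, path)
    restored = load_artifact(path)

print(restored == fitted)  # True
```

One design note: plain pickle stores functions and classes by reference, so a pickled instance of a custom class can only be loaded where that class is importable. cloudpickle can also serialize functions and classes defined interactively in __main__ (for example, in a notebook).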

To accomplish step 4, I set predictions=True in pm.sample_posterior_predictive, set my out-of-sample covariate frame as the ‘X’ data, and, for the ‘Y’ data, passed a dummy vector with the same length as X. I followed these posts: Out of model predictions with PyMC and Categorical BART with Out of Sample Predictions.
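The dummy-Y approach only works if the new X and the placeholder Y have matching lengths. Below is a small stdlib helper that builds such a payload. The name make_oos_payload is hypothetical, and the 'X'/'Y' keys assume the data containers were registered under those names. In PyMC, the resulting dict (converted to NumPy arrays in practice) would be passed to pm.set_data inside the model context, before calling pm.sample_posterior_predictive(idata, predictions=True).

```python
def make_oos_payload(X_new, x_name="X", y_name="Y", fill=0.0):
    """Build a set_data-style dict for out-of-sample prediction.

    X_new is a list of rows (one list of covariates per observation).
    Y is only a placeholder that keeps shapes consistent; its length
    must match the number of rows in X_new.
    """
    n_rows = len(X_new)
    if n_rows == 0:
        raise ValueError("X_new must contain at least one row")
    n_cols = len(X_new[0])
    if any(len(row) != n_cols for row in X_new):
        raise ValueError("all rows of X_new must have the same number of columns")
    return {x_name: X_new, y_name: [fill] * n_rows}


payload = make_oos_payload([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(len(payload["X"]), len(payload["Y"]))  # 3 3
```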

The results are inconsistent: sometimes the load process breaks, and other times it seems to work fine.

I have seen some alternative approaches to serializing the model, including the CDCgov/BART-Survival repository on GitHub. However, the methods it uses to save and load the tree structure directly do not seem to exist.

I have also tried to adapt my workflow to the pattern in the radon example, where you load a saved trace and sample from it. However, I am not sure where to start with extracting the terminal nodes from the trees.
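Conceptually, one posterior draw of a BART mean function is a sum of trees. Each tree routes an input row through its split rules to a single terminal node, and the prediction is the sum of those leaf values. If the trees were stored as plain data, out-of-sample prediction would not need the original model object. The nested-dict format and the "go left if <=" convention below are hypothetical, not pymc-bart's internal tree representation. They only illustrate what extracted terminal nodes would need to capture.

```python
def predict_tree(node, x):
    """Route one row x through a tree and return the leaf value it lands in."""
    while "value" not in node:
        # Internal node: go left if the split variable is <= the split value.
        if x[node["var"]] <= node["split"]:
            node = node["left"]
        else:
            node = node["right"]
    return node["value"]


def predict_draw(trees, x):
    """One posterior draw: sum the leaf values reached in every tree."""
    return sum(predict_tree(tree, x) for tree in trees)


# Hypothetical draw with two small trees over two covariates.
draw = [
    {"var": 0, "split": 0.5,
     "left": {"value": 1.0},
     "right": {"var": 1, "split": 2.0,
               "left": {"value": 2.0},
               "right": {"value": 3.0}}},
    {"var": 1, "split": 1.0,
     "left": {"value": -0.5},
     "right": {"value": 0.5}},
]

print(predict_draw(draw, [0.2, 5.0]))  # tree 1 -> 1.0, tree 2 -> 0.5, total 1.5
print(predict_draw(draw, [0.9, 1.5]))  # tree 1 -> 2.0, tree 2 -> 0.5, total 2.5
```

Repeating predict_draw over every stored draw would give a set of per-draw predictions for a new row, analogous to posterior predictive samples of the mean.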

Thank you very much for your time!

Apologies in advance for tagging you directly, @aloctavodia, but do you have any recommendations here? I’ve noticed you’ve done a lot of great work bringing BART to PyMC, and I’m not sure whether there is a similar use case on this forum that I haven’t come across yet.

Have you tried this?

Hi!

I’ve tried to incorporate this into my workflow (along with mlflow.pyfunc), but the main challenge so far has been defining it in a way that lets me store and call the tree structure (which would be analogous to the mean variables in the example). I could be very wrong, but to use something like this I think I would need to extract all the tree leaves first.
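Extracting the leaves amounts to a tree traversal that collects terminal nodes. The sketch below walks a hypothetical nested-dict tree format (not pymc-bart's own classes) and collects each leaf together with the split path that reaches it. It also shows that the structure survives a JSON round trip, which is one way it could be stored as a plain file artifact (for example, alongside an mlflow.pyfunc model). The format and the helper name collect_leaves are assumptions for illustration.

```python
import json


def collect_leaves(node, path=()):
    """Return (path, value) pairs for every terminal node, left to right.

    path is a tuple of (var, split, direction) steps from the root.
    """
    if "value" in node:
        return [(path, node["value"])]
    left_step = (node["var"], node["split"], "<=")
    right_step = (node["var"], node["split"], ">")
    return (collect_leaves(node["left"], path + (left_step,))
            + collect_leaves(node["right"], path + (right_step,)))


tree = {"var": 0, "split": 0.5,
        "left": {"value": 1.0},
        "right": {"var": 1, "split": 2.0,
                  "left": {"value": 2.0},
                  "right": {"value": 3.0}}}

leaves = collect_leaves(tree)
for path, value in leaves:
    print(path, value)

# The nested-dict tree is plain data, so it can be stored as JSON and reloaded.
restored = json.loads(json.dumps(tree))
print(restored == tree)  # True
```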

The BART-Survival package suggests that you can do this via model.f.owner, but that does not seem to exist in the current version of pymc-bart.