How to save BART idata?

Hi!

I encountered an error when trying to use to_netcdf() to save the idata of a BART model.

idata.to_netcdf(model_fpath)

ValueError: unable to infer dtype on variable 'bart_trees'; xarray cannot serialize arbitrary Python objects

I’m currently using cloudpickle to save the idata instead, but I’m not sure if that’s the correct way of doing so.

import cloudpickle

# save the model and its trace together
with open(model_fpath, 'wb') as buff:
    cloudpickle.dump({'model': model, 'trace': idata}, buff)

# load
with open(model_fpath, 'rb') as buff:
    saved_model = cloudpickle.load(buff)

model_copy = saved_model['model']
idata_copy = saved_model['trace']

What would be the recommended way of saving an idata with bart_trees?

Thank you for any help!

Hi again, yiyi-z.

If you want to save the idata for future predictions, try the code below, which has worked for me.

Let's say you have an idata like this.

idata = pm.sample(...)

You can save it like this.

import pickle

with open('my_bart_model.pkl', 'wb') as p:
    pickle.dump(idata, p)

Then, when you want to load it and use it for prediction:

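import pickle

# load inside a with-block so the file is closed afterwards
with open('my_bart_model.pkl', 'rb') as p:
    idata = pickle.load(p)

# pmb is pymc_bart; rng and newDataFrame must already be defined
samples_from_posterior = pmb.predict(idata, rng, newDataFrame.values, 100)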
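
If you want to check the save/load round trip before trusting it with a real trace, here is a minimal sketch. Note that fake_idata is only a plain dict standing in for a real idata (an assumption, so the snippet runs without PyMC installed), and the helper names save_pickle and load_pickle are made up for illustration.

```python
import pickle
import tempfile
from pathlib import Path

# Stand-in for a real idata: any picklable Python object round-trips the same way.
fake_idata = {"posterior": {"mu": [0.1, 0.2, 0.3]}, "note": "stand-in for bart_trees"}


def save_pickle(obj, path):
    """Write obj to path, closing the file even if dumping fails."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(path):
    """Read a pickled object back from path."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Use a temporary directory so the check leaves no files behind.
with tempfile.TemporaryDirectory() as tmp:
    fpath = Path(tmp) / "my_bart_model.pkl"
    save_pickle(fake_idata, fpath)
    restored = load_pickle(fpath)

# The restored object is an equal copy, not the same object.
print(restored == fake_idata)
```

One caveat: a pickle file can run arbitrary code when it is loaded, so only load files that you created yourself or that come from a source you trust.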

Thanks.


Thank you y_m! :laughing:

Just for reference, you can now use

idata.to_netcdf()