Hi, I have built a BART model for a regression task and plan to deploy it to a web platform. The problem I found is that, as far as I can tell, PyMC can only save the trace data (az.to_netcdf). After reloading the trace data (az.from_netcdf), the model still needs to be retrained. Because the training set is large, each training run takes more than 10 minutes, which is not workable for a practical web deployment.
I have also tried the solution described in “Save and Load a BART model”, but it does not work.
So I would like to ask whether there is a general way to use a PyMC model for prediction without retraining it. Thank you very much; I look forward to your reply.
My BART code is as follows:
import arviz as az
import numpy as np
import pymc as pm
import pymc_bart as pmb

# build model
with pm.Model() as E_ulti:
    X = pm.MutableData("X", X_train)
    Y = pm.MutableData("Y", Y_train)
    α = pm.Exponential("α", 1)
    μ = pmb.BART("μ", X, np.log(Y_train), m=200, split_rules=split_rules)
    y = pm.Normal("y", mu=pm.math.exp(μ), sigma=α, observed=Y, shape=μ.shape)
    idata_ulti = pm.sample(random_seed=RANDOM_SEED)

# save trace
az.to_netcdf(idata_ulti, 'trace_c.nc')

# load trace
trace_c = az.from_netcdf('trace_c.nc')

# sample predictive (same model object as defined above)
with E_ulti:
    X.set_value(xtest_pd)
    posterior_predictive_xtest = pm.sample_posterior_predictive(trace=trace_c, random_seed=RANDOM_SEED)

# check result
az.summary(posterior_predictive_xtest, var_names="y", kind="stats", round_to=4)
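
To make the goal concrete, what I am hoping for is a pattern like the sketch below: train once offline, persist the fitted object, and have the web process only load it. The helper name `load_or_train`, the cache path, and the use of `pickle` are my own assumptions for illustration. I have not confirmed that a PyMC BART model and its sampled trees survive pickling, so this only shows the workflow I am after, not a working solution for BART.

```python
import pickle
from pathlib import Path


def load_or_train(cache_path, train_fn):
    """Return the cached fitted object if it exists, else train and cache it.

    `train_fn` is a hypothetical zero-argument callable that runs the
    expensive training step and returns a picklable fitted object.
    """
    path = Path(cache_path)
    if path.exists():
        # Fast path for the web process: no training, just deserialize.
        with path.open("rb") as f:
            return pickle.load(f)
    # Slow path: run training once, then persist the result for later loads.
    fitted = train_fn()
    with path.open("wb") as f:
        pickle.dump(fitted, f)
    return fitted
```

With something like this, the 10-minute training would run only on the first call, and every later call with the same cache path would just read the file.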