The following simple code does not work.
How can I save the results after VI? pm.save_trace() and pm.load_trace() work perfectly with MCMC traces.
Thank you in advance.
import pymc3 as pm

with pm.Model() as model:
    # simple model
    ...

with model:
    advi = pm.ADVI()
    approx = advi.fit()
    mytrace = approx.sample(100000)
    # mytrace is <MultiTrace: 1 chains, 100000 iterations, 14 variables>
    pm.save_trace(mytrace, "Some_kind_of_directory")
    mytrace2 = pm.load_trace("Some_kind_of_directory", model=model)
    # raises TypeError: 'NoneType' object is not iterable
I installed Python and pymc3 via Anaconda, and arviz via pip.
I have already tried updating pymc3 to 3.7 and re-installing Python and every package.
Updating to 3.7 caused another problem: pm.traceplot() became significantly slower.
It took more than 5 minutes to plot the result of a very simple model.
Thanks for reporting back. Could you share your model as well? We will take a look at the slowness in ArviZ.
Meanwhile, @ferrine I think the load_trace breakage is a bug, have you seen it before?
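Until load_trace works with VI samples, one possible workaround (an assumption, not an official PyMC3 API) is to skip save_trace and store the posterior draws as plain NumPy arrays. For a MultiTrace `trace`, such a dict could be built as `{name: trace[name] for name in trace.varnames}`. The sketch below uses stand-in draws so it runs without PyMC3; `save_samples` and `load_samples` are hypothetical helper names.

```python
import os
import tempfile

import numpy as np


def save_samples(samples, path):
    """Save a dict of {variable name: array of draws} to one compressed .npz file."""
    np.savez_compressed(path, **samples)


def load_samples(path):
    """Load the .npz file back into a plain dict of NumPy arrays."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}


# Stand-in draws (hypothetical); with PyMC3 this dict could come from a MultiTrace
rng = np.random.default_rng(0)
samples = {
    "kappa_1": rng.gamma(1.0, 1.0, size=100),
    "w": rng.dirichlet([1.0, 1.0], size=100),
}

path = os.path.join(tempfile.mkdtemp(), "vi_samples.npz")
save_samples(samples, path)
restored = load_samples(path)
```

This only preserves the draws, not the model or the fitted approximation, so anything that needs the model (for example posterior predictive sampling) still has to rebuild it.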
This is my model.
The data has about 30k observations, and it is circular.
import numpy as np
import pymc3 as pm

# data: ~30k circular observations (pandas Series)
with pm.Model() as model:
    kappa_1 = pm.Gamma('kappa_1', 1., 1.)
    mu_1 = pm.VonMises('mu_1', mu=0, kappa=0.5)
    component1 = pm.VonMises.dist(mu=mu_1, kappa=kappa_1)

    kappa_2 = pm.Gamma('kappa_2', 1., 1.)
    mu_2 = pm.VonMises('mu_2', mu=np.pi, kappa=0.5)
    component2 = pm.VonMises.dist(mu=mu_2, kappa=kappa_2)

    w = pm.Dirichlet('w', np.ones(2))
    vm = pm.Mixture('vm', w=w, comp_dists=[component1, component2],
                    observed=data.values)

with model:
    fullrank = pm.FullRankADVI()
    tracker = pm.callbacks.Tracker(
        mean=fullrank.approx.mean.eval,  # callable that returns mean
        std=fullrank.approx.std.eval,    # callable that returns std
    )
    approx = fullrank.fit(100000, callbacks=[tracker])
    pm.save_trace(trace=approx.sample(100), directory="test", overwrite=True)
    trace = pm.load_trace("test", model=model)
    # raises TypeError: 'NoneType' object is not iterable
pymc3:3.7
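The model above is a two-component von Mises mixture: each observed angle has density w_1 * VM(x | mu_1, kappa_1) + w_2 * VM(x | mu_2, kappa_2). As a NumPy-only sketch of that density (independent of PyMC3; the parameter values are hypothetical), the von Mises density is exp(kappa * cos(x - mu)) / (2 * pi * I0(kappa)), where I0 is the modified Bessel function of order zero:

```python
import numpy as np


def vonmises_pdf(x, mu, kappa):
    """Von Mises density exp(kappa*cos(x - mu)) / (2*pi*I0(kappa)) on the circle."""
    return np.exp(kappa * np.cos(x - mu)) / (2 * np.pi * np.i0(kappa))


def mixture_pdf(x, w, mus, kappas):
    """Weighted sum of von Mises components, analogous to pm.Mixture(w=w, comp_dists=[...])."""
    return sum(
        w_k * vonmises_pdf(x, mu_k, kappa_k)
        for w_k, mu_k, kappa_k in zip(w, mus, kappas)
    )


# Hypothetical parameter values, for illustration only
w = [0.3, 0.7]
mus = [0.0, np.pi]
kappas = [2.0, 5.0]

# Sanity check: the density integrates to 1 over one full turn of the circle
# (an equally spaced sum is very accurate for smooth periodic functions)
n = 4096
x = np.linspace(-np.pi, np.pi, n, endpoint=False)
total = mixture_pdf(x, w, mus, kappas).sum() * (2 * np.pi / n)
```

Computing exp and I0 directly overflows for very large kappa; scipy.stats.vonmises is a more robust alternative for real use.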
I found that this seems to be some kind of memory problem.
The reported elapsed time is very short, but it takes a few minutes for the plot to appear in my Jupyter notebook.
Increasing the number of draws to approx.sample(10000) raises a MemoryError.
This did not happen in pymc3 3.6.0.
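One possible, unconfirmed explanation for both the slowness and the MemoryError: if converting the trace for plotting stores one pointwise log-likelihood value per draw and per observation, memory grows with draws × observations rather than with the number of parameters. This is an assumption, not something established in this thread; `pointwise_loglik_bytes` is a hypothetical helper that just does the arithmetic for roughly 30k observations:

```python
def pointwise_loglik_bytes(n_chains, n_draws, n_obs, itemsize=8):
    """Bytes for a float array of shape (n_chains, n_draws, n_obs).

    Assumes (hypothetically) one float64 log-likelihood value is stored
    per chain, draw and observation.
    """
    return n_chains * n_draws * n_obs * itemsize


n_obs = 30_000  # approximate data size from the post
for n_draws in (100, 10_000):
    gb = pointwise_loglik_bytes(1, n_draws, n_obs) / 1e9
    print(f"{n_draws} draws -> {gb:.3f} GB")
```

Under that assumption the estimate is about 24 MB for 100 draws but about 2.4 GB for 10,000 draws, which would be consistent with a MemoryError appearing only at the larger sample size.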
Happy to help, and we appreciate your feedback! @RavinKumar could you check whether you can reproduce the slowness?
In general, since we are using the same plotting backend, I would not expect a slowdown, but I could be wrong.