Unable to load trace after VI

The following simple code does not work.
How can I save the results after VI?
pm.save_trace() and pm.load_trace() work perfectly with MCMC traces.
Thank you in advance.

import pymc3 as pm

with pm.Model() as model:
    ...  # simple model (definition omitted)
with model:
    advi = pm.ADVI()
    approx = advi.fit()
    mytrace = approx.sample(100000)
# mytrace is <MultiTrace: 1 chains, 100000 iterations, 14 variables>

pm.save_trace(mytrace, "Some_kind_of_directory")
mytrace2 = pm.load_trace("Some_kind_of_directory", model=model)

TypeError: 'NoneType' object is not iterable


@ferrine do you know if VI supports this?

I had assumed the MultiTrace object returned by VI was almost identical to the one returned by MCMC…
I was wondering why everyone uses pickle.
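For reference, the pickle approach mentioned here can be sketched with the standard library alone. The helper names `save_pickled` and `load_pickled` are hypothetical, and the sketch assumes the object you pass (a trace or an approximation) is picklable in your PyMC3 version, which this thread does not confirm.

```python
import pickle
import tempfile
from pathlib import Path


def save_pickled(obj, path):
    """Serialize any picklable object (for example a trace) to `path`."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh, protocol=pickle.HIGHEST_PROTOCOL)


def load_pickled(path):
    """Load an object written by save_pickled. Only unpickle files you trust."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


if __name__ == "__main__":
    # A plain dict stands in for the trace so this sketch runs without PyMC3.
    stand_in = {"kappa_1": [0.9, 1.1], "mu_1": [0.05, -0.02]}
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "trace.pkl"
        save_pickled(stand_in, path)
        print(load_pickled(path) == stand_in)  # True
```

With PyMC3 you would pass `mytrace` (or `approx`) in place of `stand_in`; reloading a pickle generally requires compatible versions of the libraries that wrote it.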

For the most part, it should behave like a MultiTrace…

There may be a problem in my local environment.
I'll try the code on another PC or on Google Colab.

Thank you as always!

I still can't use pm.load_trace().

Windows 10, Core i7-8700 (3.2 GHz)
RAM : 16 GB

Python : 3.7.4
pymc3 : 3.6
arviz : 0.5.1

I installed Python and pymc3 with Anaconda, and arviz with pip.
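When comparing environments, it can help to print installed versions programmatically rather than from memory. A minimal sketch using the standard library's `importlib.metadata`; note that this module requires Python 3.8 or later, so on the Python 3.7.4 listed above you would instead print `pm.__version__` and `az.__version__` directly. The helper name `installed_versions` is hypothetical.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("Python:", platform.python_version())
    for pkg, ver in installed_versions(["pymc3", "arviz", "theano", "matplotlib"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```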

I have already tried updating pymc3 to 3.7 and reinstalling Python and every package.
Updating to 3.7 caused another problem: pm.traceplot() became significantly slower.
It took more than 5 minutes to plot the results of a very simple model.

Has anyone experienced similar problems?

What error are you seeing?

As for the slowness, it is likely caused by an older version of ArviZ. Did you also upgrade that?

As for saving and loading the trace, I get the same error as in the post above:
'NoneType' object is not iterable

There was a similar post in September: "Unable to load trace".

arviz 0.5.1 seems to be the latest version.
Here is what I just tried.
Both resulted in a very slow pm.traceplot.

1: Reinstall arviz via pip.

pip uninstall arviz
pip install git+https://github.com/arviz-devs/arviz

2: Reinstall arviz via conda

pip uninstall arviz 
conda uninstall arviz
conda install -c conda-forge arviz

Thanks for reporting back. Could you share your model as well? We will try to take a look at the slowness in ArviZ.
Meanwhile, @ferrine, I think the load_trace breakage is a bug. Have you seen it before?

For now, see "Saving ADVI results and reloading" for how to save the approximation instead of the trace.
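As a rough illustration of saving the approximation rather than the trace: the `Tracker` callback used later in this thread reads `approx.mean.eval()` and `approx.std.eval()`, so those vectors could be stored with NumPy and reloaded. This is an assumption-laden sketch, not the method from the linked post. The helper names are hypothetical, the arrays below are dummy values rather than real `approx` output, and redrawing assumes a diagonal Gaussian in the approximation's unconstrained parameter space, so it discards FullRankADVI correlations and does not map samples back through PyMC3's variable transforms.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_gaussian_summary(path, mean, std):
    """Store the flat mean and std vectors of a Gaussian approximation in an .npz file."""
    np.savez(path, mean=np.asarray(mean, dtype=float), std=np.asarray(std, dtype=float))


def load_gaussian_summary(path):
    """Return the (mean, std) arrays saved by save_gaussian_summary."""
    with np.load(path) as data:
        return data["mean"], data["std"]


def redraw(mean, std, n_draws, seed=None):
    """Draw n_draws samples from independent normals N(mean_i, std_i).

    Assumes a diagonal (mean-field) Gaussian in the unconstrained space;
    correlations from FullRankADVI are NOT preserved.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(loc=mean, scale=std, size=(n_draws, mean.size))


if __name__ == "__main__":
    # Dummy vectors; in PyMC3 these might come from approx.mean.eval() / approx.std.eval().
    mean, std = np.array([0.1, -0.3, 1.2]), np.array([0.05, 0.2, 0.1])
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "approx_summary.npz"
        save_gaussian_summary(path, mean, std)
        m, s = load_gaussian_summary(path)
    draws = redraw(m, s, n_draws=1000, seed=0)
    print(draws.shape)  # (1000, 3)
```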

This is my model.
The data size is about 30k, and the data are circular.

import numpy as np
import pymc3 as pm

# data: ~30k circular observations (not shown)
with pm.Model() as model:
    kappa_1 = pm.Gamma('kappa_1', 1., 1.)
    mu_1 = pm.VonMises('mu_1', mu=0, kappa=0.5)
    component1 = pm.VonMises.dist(mu=mu_1, kappa=kappa_1)
    kappa_2 = pm.Gamma('kappa_2', 1., 1.)
    mu_2 = pm.VonMises('mu_2', mu=np.pi, kappa=0.5)
    component2 = pm.VonMises.dist(mu=mu_2, kappa=kappa_2)

    w = pm.Dirichlet('w', np.ones(2))
    vm = pm.Mixture('vm', w=w, comp_dists=[component1, component2],
                    observed=data.values)
with model:
    fullrank = pm.FullRankADVI()
tracker = pm.callbacks.Tracker(
    mean=fullrank.approx.mean.eval,  # callable that returns mean
    std=fullrank.approx.std.eval,  # callable that returns std
)
approx = fullrank.fit(100000, callbacks=[tracker])

pm.save_trace(trace=approx.sample(100), directory="test", overwrite=True)
trace = pm.load_trace("test", model=model)
# raises TypeError: 'NoneType' object is not iterable
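Since `data` itself isn't shared, anyone trying to reproduce this could use a synthetic stand-in of similar size. A sketch using NumPy's von Mises sampler; the component locations follow the prior centers in the model above (0 and π), but the concentrations, weights, and the `make_circular_mixture` name are made-up values for illustration only. It returns a NumPy array, so it would be passed as `observed=data` rather than `data.values`.

```python
import numpy as np


def make_circular_mixture(n=30_000, mus=(0.0, np.pi), kappas=(4.0, 2.0),
                          weights=(0.6, 0.4), seed=None):
    """Draw n angles in [-pi, pi] from a mixture of von Mises distributions."""
    rng = np.random.default_rng(seed)
    # Pick a mixture component for every observation according to `weights`.
    comp = rng.choice(len(weights), size=n, p=weights)
    mus, kappas = np.asarray(mus), np.asarray(kappas)
    # Draw each observation from its assigned component (arrays broadcast elementwise).
    return rng.vonmises(mus[comp], kappas[comp])


if __name__ == "__main__":
    data = make_circular_mixture(seed=42)
    print(data.shape, data.min() >= -np.pi, data.max() <= np.pi)
```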


I found out that this is some kind of memory problem.
The measured elapsed time is very short, but it took a few minutes for the plot to appear in my Jupyter notebook.
Increasing the number of samples to approx.sample(10000) raised a MemoryError.
This didn't happen in pymc3 3.6.0.

import time
start = time.time()
pm.traceplot(trace)  # the plotting call being timed
elapsed_time = time.time() - start
print("elapsed_time:{0}[sec]".format(elapsed_time))
# elapsed_time:0.0029921531677246094[sec]


The plots appeared within a few seconds.

import time
start = time.time()
pm.traceplot(trace)  # the plotting call being timed
elapsed_time = time.time() - start
print("elapsed_time:{0}[sec]".format(elapsed_time))
# elapsed_time:0.0[sec]
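A note on these timings: in a Jupyter notebook with the inline backend, the figure is typically rendered after the cell's code has finished, so a `time.time()` difference taken inside the cell probably does not include the rendering that took minutes here. For timing a call itself, a small reusable helper based on `time.perf_counter()` (monotonic and higher resolution than `time.time()`) could look like this; the name `timed` is hypothetical.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label="elapsed_time"):
    """Measure the wall-clock time of the enclosed block with a monotonic clock."""
    result = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.4f} [sec]")


if __name__ == "__main__":
    with timed("sleep") as t:
        time.sleep(0.05)  # stand-in for pm.traceplot(trace)
    print(t["seconds"] >= 0.05)  # True
```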


OK! Thank you!
I really appreciate your reply.
I don't know how to apologize for wasting your time if this turns out to be a problem with my local environment…

Happy to help, and we appreciate your feedback! @RavinKumar, could you check whether you can reproduce the slowness issue?
In general, since we are using the same plotting backend, I would not expect a slowdown, but I could be wrong.

Was the slowness reproducible?

I also encountered the same problem. I tried running the code on

and then ran
pm.save_trace(trace, 'test_trace')
pm.load_trace(directory='test_trace', model=neural_network)

which results in the following error in pm.load_trace:
TypeError: 'NoneType' object is not iterable

I tried on Colab, my local machine and AWS, all resulting in the same error.