Saving the state of the sampler in pymc3?

I am using Metropolis-Hastings as the step method of the sampler in my code. I was wondering whether it is possible to save the state of the sampler so that it can be loaded later to continue sampling. I know that such functionality is already implemented in PyMC2, but I am using PyMC3, and due to memory allocation issues I cannot run long chains in one go.
I also know that the trace can be saved and loaded, but I need the sampler state to be saved.
I would appreciate any help regarding this.

What do you mean by sampler state? Something like the scaling of the proposal distribution?
Also, I strongly encourage you not to use random-walk-based MH; you should use the default (NUTS/HMC).

Thanks for the reply.
By the sampler state I mean everything the sampler needs to continue from the last state it was in. In other words, I do not want to start over again with retuning. If scaling is the only tuned parameter, and restoring it guarantees that the sampler continues from its last state, how can I retrieve it?
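The next sketch makes "everything needed to continue" concrete. It is library-independent plain Python, not PyMC3's API. For a random-walk Metropolis chain, the state consists of four things:

- the current point;
- its log density;
- the tuned proposal scale;
- the state of the random generator.

The target `log_density`, the `MHState` class and the `run` helper are hypothetical names used only for illustration.

```python
import math
import pickle
import random
from dataclasses import dataclass
from typing import Optional


def log_density(x):
    # Hypothetical target: standard normal, up to a constant.
    return -0.5 * x * x


@dataclass
class MHState:
    position: float                    # current point of the chain
    logp: float                        # log density at `position`
    scaling: float                     # tuned proposal standard deviation
    rng_state: Optional[tuple] = None  # snapshot of the random generator


def run(state, n_draws, rng):
    """Advance a random-walk Metropolis chain; return (draws, new state)."""
    x, lp = state.position, state.logp
    draws = []
    for _ in range(n_draws):
        proposal = x + rng.gauss(0.0, state.scaling)
        lp_prop = log_density(proposal)
        # Symmetric proposal: accept with probability min(1, p'/p).
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            x, lp = proposal, lp_prop
        draws.append(x)
    return draws, MHState(x, lp, state.scaling, rng.getstate())


start = MHState(position=3.0, logp=log_density(3.0), scaling=0.8)

# Reference: 200 draws in one go.
full, _ = run(start, 200, random.Random(42))

# Interrupted: 100 draws, checkpoint to bytes (e.g. a file), resume later.
first, checkpoint = run(start, 100, random.Random(42))
blob = pickle.dumps(checkpoint)

restored = pickle.loads(blob)
rng = random.Random()
rng.setstate(restored.rng_state)       # continue the same random stream
second, _ = run(restored, 100, rng)

print(first + second == full)          # True
```

If any of the four pieces is dropped, the resumed run stops being a true continuation. Restarting from the last point with a fresh scale is one example.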

In the Metropolis class, the stats are the following dictionary, but when I run trace.stat_names only “tune” and “accept” are listed.

stats = {
    'tune': self.tune,
    'scaling': self.scaling,
    'accept': np.exp(accept),
    'accepted': accepted,
}

As for why I am using MH: I started with the default step method (NUTS), but it was very slow. It also could not recover one of the easily estimated parameters of the model (maybe it needed longer tuning and more draws). MH, by contrast, gave me pretty good estimates in a short time. The problem with MH is that it runs out of memory as the number of parameters increases, so I need to run it multiple times, and it is important to start from where it left off.

Right now I just save the last sample of the chain and use it as the starting point of the new chain. Because I have to retune, the new chains definitely do not behave as continuations of the previous chains.

Any help would be greatly appreciated.

Before I get to a possible code solution, I want to highlight something about NUTS. When NUTS is slow, it is most likely because of model misspecification. That is a huge warning sign that you should look more closely at your model implementation and look for optimizations. In most applied use cases, MH will give you one of two outcomes:

- the same result, but with worse effective sample size and R-hat (i.e., a faster run time but fewer effective samples); or
- a different, wrong estimate (fast convergence to a local mode, returning a biased result).

Now that you hopefully understand the risk of using MH, and my strong discouragement of doing so, here are some possible solutions to your question.
In principle, you can initialize the step method by hand (step = pm.Metropolis(...)) and turn off all tuning. It would go a bit like this.
Initialize your model:

with pm.Model():
    # model specification here
    step = pm.Metropolis(...)
    trace = pm.sample(..., step=step)

Afterward, you can either pickle step, or read the parameters from step and initialize a tuned version of pm.Metropolis.
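The two options (pickle the step object, or extract the tuned parameter and rebuild the step with tuning switched off) can be illustrated with a toy step method. `ToyMetropolis` and its crude "multiply by 1.1 or 0.9" tuning rule are hypothetical stand-ins, not `pymc3.Metropolis`.

```python
import math
import pickle
import random


def std_normal_logp(x):
    # Hypothetical target density (module-level so it can be pickled).
    return -0.5 * x * x


class ToyMetropolis:
    """Toy random-walk step method; NOT pymc3.Metropolis, just an analogy."""

    def __init__(self, logp, scaling=1.0, tune=True, tune_interval=100):
        self.logp = logp
        self.scaling = scaling
        self.tune = tune
        self.tune_interval = tune_interval
        self._n_accepted = 0
        self._n_since_tune = 0

    def step(self, x, rng):
        proposal = x + rng.gauss(0.0, self.scaling)
        log_ratio = self.logp(proposal) - self.logp(x)
        accepted = rng.random() < math.exp(min(0.0, log_ratio))
        if accepted:
            x = proposal
        if self.tune:
            # Adapt the scale from the acceptance rate of each interval.
            self._n_accepted += accepted
            self._n_since_tune += 1
            if self._n_since_tune == self.tune_interval:
                rate = self._n_accepted / self.tune_interval
                if rate > 0.5:
                    self.scaling *= 1.1
                elif rate < 0.2:
                    self.scaling *= 0.9
                self._n_accepted = self._n_since_tune = 0
        return x


rng = random.Random(0)
step = ToyMetropolis(std_normal_logp, scaling=0.1)
x = 5.0
for _ in range(1000):                       # tuning phase
    x = step.step(x, rng)

# Option 1: pickle the tuned step object itself.
restored = pickle.loads(pickle.dumps(step))

# Option 2: keep only the tuned parameter and rebuild with tuning off.
frozen = ToyMetropolis(std_normal_logp, scaling=step.scaling, tune=False)
for _ in range(500):                        # continuation, no tuning
    x = frozen.step(x, rng)

print(restored.scaling == step.scaling, frozen.scaling == step.scaling)
```

Option 1 only helps if the object you pickle is the one that was actually tuned. When chains run in separate processes, the tuned values may exist only in the workers' copies of the step (an assumption worth checking in PyMC3). Reading the scaling back from the trace, as in option 2, sidesteps that question.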


Thanks. I tried it using the following code, but it does not work. Here is what I do:

  1. Run the sampler with 5000 tuning steps and 2000 draws, then pickle the step object and the last sample of every chain (stored in a dictionary).

  2. Load the pickled step and the dictionary of starting points, pass them to the new sampler, and turn tuning off, so that it only draws 2000 samples with no tuning.

For both steps above, I plot the model logp every 50 steps. Here is what it looks like:

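Step 1: [plot of model logp every 50 steps, not reproduced]

Step 2: [plot of model logp every 50 steps, not reproduced]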

In step 2 the sampler makes no useful updates: the curve flattens at a logp close to the last value of the previous sampler. It also warns that I should increase the number of tuning steps.

I also wanted to check what happens if the step-1 sampler simply continues for 2000 more steps: does the model logp keep improving, or does it flatten? To test this, I repeated step 1 with 5000 tuning steps and 4000 draws instead of 2000. This confirms that step 2 actually requires tuning. Here is the model logp plot for 4000 draws in step 1:

Step 1 with 4000 draws: [plot of model logp every 50 steps, not reproduced]

As the plot shows, the model logp keeps improving and does not flatten. So I think that in step 2 the sampler is not really continuing from the state of the sampler in step 1.

I am using a fixed SEED in all the runs so that the behavior of the sampler stays the same across the runs.
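A fixed seed makes runs reproducible, but it does not by itself continue a random stream. Re-seeding the second run with the same SEED restarts the random numbers from the beginning. The snippet below uses Python's standard `random` module purely as an analogy. How PyMC3 derives per-chain seeds from `random_seed` is not modelled here.

```python
import random

SEED = 123

# Session 1: draw 5 proposal noises, then stop and checkpoint the generator.
rng = random.Random(SEED)
session1 = [rng.gauss(0.0, 1.0) for _ in range(5)]
saved_state = rng.getstate()

# Session 2, option A: re-seed with the same SEED -> the stream restarts.
rng_a = random.Random(SEED)
replayed = [rng_a.gauss(0.0, 1.0) for _ in range(5)]

# Session 2, option B: restore the saved state -> the stream continues.
rng_b = random.Random()
rng_b.setstate(saved_state)
continued = [rng_b.gauss(0.0, 1.0) for _ in range(5)]

# Reference: what one uninterrupted session would have produced.
rng_ref = random.Random(SEED)
reference = [rng_ref.gauss(0.0, 1.0) for _ in range(10)]

print(replayed == session1)         # True: the same draws are repeated
print(continued == reference[5:])   # True: a genuine continuation
```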

Please let me know if I have made any mistake in the code and, if not, what is missing here.

My code:

# model, cores, SEED, dir_, exists_start and load_obj are defined elsewhere (not shown)
draws = 2000
tune = 5000
target_accept = 0.65

if exists_start:  # step 2: resume from the pickled step and last samples
    start_dict_lst = load_obj(dir_=dir_, name='start_point')
    step = load_obj(dir_=dir_, name='step')
    with model:
        model_trace = sample(draws, tune=0, target_accept=target_accept,
                             cores=cores, step=step, random_seed=SEED,
                             start=start_dict_lst, discard_tuned_samples=True)
else:  # step 1: tune and draw from scratch
    with model:
        step = pm.Metropolis()
        model_trace = sample(draws, tune=tune, target_accept=target_accept,
                             cores=cores, step=step, random_seed=SEED)
    # (pickling `step` and the last sample of each chain happens afterwards; not shown)

Sorry, I could not figure out how to align the code lines properly after pasting.

Thanks for your time.
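The post does not show `load_obj` (or the matching save step). Below is a hypothetical pickle-based pair with the same call signature, `load_obj(dir_=..., name=...)`. It is not the poster's actual helper.

The checkpoint stores each chain's last point and tuned scaling as plain data, separately from the step object. Both values are made up, and `'x'` is a placeholder variable name.

```python
import os
import pickle
import tempfile


def save_obj(obj, dir_, name):
    """Pickle `obj` to <dir_>/<name>.pkl (hypothetical helper)."""
    os.makedirs(dir_, exist_ok=True)
    with open(os.path.join(dir_, name + '.pkl'), 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj(dir_, name):
    """Load what save_obj wrote (hypothetical helper)."""
    with open(os.path.join(dir_, name + '.pkl'), 'rb') as f:
        return pickle.load(f)


# Example checkpoint: last point of each chain plus each chain's tuned
# scaling (made-up values; 'x' is a placeholder variable name).
checkpoint = {
    'start_point': [{'x': 0.3}, {'x': -1.2}],
    'scaling': [[1.1, 1.5], [2.2, 2.5]],
}

dir_ = tempfile.mkdtemp()
for name, obj in checkpoint.items():
    save_obj(obj, dir_=dir_, name=name)

start_dict_lst = load_obj(dir_=dir_, name='start_point')
print(start_dict_lst)  # [{'x': 0.3}, {'x': -1.2}]
```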

Looking at the trace, it is apparent that MH is doing an extremely bad job here; I strongly advise you NOT to use it.

You can use Markdown to format code.

From a pure coding standpoint, here is what you can do:

In [1]: import numpy as np

In [2]: import pymc3 as pm

In [3]: with pm.Model() as m:
   ...:     x = pm.Normal('x', 0., 5.)
   ...:     y = pm.Normal('y', x, 1., shape=2)
   ...:     step = pm.Metropolis()
   ...:     trace = pm.sample(step=step)
Multiprocess sampling (2 chains in 2 jobs)
>Metropolis: [y]
>Metropolis: [x]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 1 seconds.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.

In [4]: tuned_scaling = trace.get_sampler_stats('scaling')[-1].tolist()

In [5]: with m:
   ...:     step_tuned = [pm.Metropolis(step_last.vars, scaling=s, tune=False) for step_last, s in zip(step.methods, tuned_scaling)]
   ...:     trace2 = pm.sample(tune=0, step=step_tuned)
Multiprocess sampling (2 chains in 2 jobs)
>Metropolis: [y]
>Metropolis: [x]
Sampling 2 chains for 0 tune and 1_000 draw iterations (0 + 2_000 draws total) took 0 seconds.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.

In [6]: trace.get_sampler_stats('scaling')
array([[1.331   , 1.771561],
       [1.331   , 1.771561],
       [1.331   , 1.771561],
       [1.331   , 1.61051 ],
       [1.331   , 1.61051 ],
       [1.331   , 1.61051 ]])

In [7]: trace2.get_sampler_stats('scaling')
array([[1.331  , 1.61051],
       [1.331  , 1.61051],
       [1.331  , 1.61051],
       [1.331  , 1.61051],
       [1.331  , 1.61051],
       [1.331  , 1.61051]])
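The scalings printed above are powers of 1.1 (1.331 = 1.1³, 1.61051 = 1.1⁵, 1.771561 = 1.1⁶). That fits a tuning rule which rescales the proposal by a fixed factor, chosen from the acceptance rate over each tuning interval.

The sketch below shows such a rule. The thresholds are as I recall them from `pymc3/step_methods/metropolis.py` and are an assumption to verify against your installed version. `tune_scaling` is a hypothetical name.

```python
def tune_scaling(scale, acc_rate):
    """Acceptance-rate-based scale update (thresholds assumed to mirror
    PyMC3's Metropolis tuning; verify against your version)."""
    if acc_rate < 0.001:
        return scale * 0.1
    elif acc_rate < 0.05:
        return scale * 0.5
    elif acc_rate < 0.2:
        return scale * 0.9
    elif acc_rate > 0.95:
        return scale * 10.0
    elif acc_rate > 0.75:
        return scale * 2.0
    elif acc_rate > 0.5:
        return scale * 1.1
    return scale  # 0.2 <= acc_rate <= 0.5: leave unchanged


# Three tuning intervals with an acceptance rate between 0.5 and 0.75
# take a scale of 1.0 to 1.1**3 (1.331 up to floating-point rounding).
s = 1.0
for rate in (0.6, 0.6, 0.6):
    s = tune_scaling(s, rate)
print(round(s, 6))  # 1.331
```

The scale only changes while tuning is on. With `tune=False` it stays fixed, as the `trace2` output shows, so the last tuned value of each chain is what a continuation run needs.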



Thanks for providing the code. But as I mentioned in my earlier replies, when I print the stats of the trace (trace.stat_names), it only shows ‘tune’ and ‘accept’. So if I try to get ‘scaling’ from the trace, I get the error “KeyError: ‘Unknown sampler statistic scaling’”. I am using PyMC3 3.5.

Regarding switching to NUTS/HMC, I need to figure out where the model misspecification is and how to fix it, so please point me to some resources on this.
At the beginning, sampling runs at 3 or 4 draws/s and then increases to 12 draws/s. This behavior can change depending on which variables are updated. For example, my model has some weights drawn from a Dirichlet distribution. If I fix all the other variables (known to the model/sampler) and ask the sampler to update only these weights, the sampling rate is very low. Sometimes, depending on the number of variables involved, it stays around 2 or 3 draws/s.
W = pm.Dirichlet('W', a=mu_w.T, shape=(N, C), transform=t_stick_breaking(eps))

Regardless of which step method is used, I have to solve the memory allocation issue, so I will also need to run NUTS/HMC multiple times. I assume that will be more challenging than with MH.

Hi @sabagh,
For background on MCMC algorithms and the (good) points @junpenglao is making, I strongly recommend Statistical Rethinking by Richard McElreath and Bayesian Analysis with Python by our very own Osvaldo Martin. The PyMC website also has an educational resources page.
You can also watch the accompanying videos of Richard’s course; they are very interesting.

I also recommend this in-depth article about MCMC in practice, although it is more theory-heavy.

Finally, I’d advise updating to the latest stable version if you can, i.e., 3.9.3.
Hope it helps!


Thanks. I will definitely go through them. @AlexAndorra

@junpenglao, the code was really helpful. Just a minor point: in In [4], where the scaling is retrieved, only the final chain is considered. So when we use that scaling for the two new chains, only one of them works properly, i.e., actually continues its previous chain. How can I provide the sampler with two sets of scalings? Instead of having:

tuned_scaling = trace.get_sampler_stats('scaling')[-1].tolist()

we should have:
chains_scalings = trace.get_sampler_stats('scaling', combine=False)
chain_1_scalings = chains_scalings[0][-1].tolist()
chain_2_scalings = chains_scalings[1][-1].tolist()
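The difference between the two extractions can be illustrated in plain Python with made-up numbers, using lists in place of NumPy arrays. The layout mimics `get_sampler_stats(..., combine=False)`:

- one entry per chain;
- one row per draw;
- one column per Metropolis sub-step.

```python
# Made-up 'scaling' stats: one list per chain, one row per draw,
# one column per Metropolis sub-step.
chains_scalings = [
    [[1.0, 1.5], [1.0, 1.5], [1.1, 1.5]],   # chain 0
    [[2.0, 2.5], [2.0, 2.5], [2.2, 2.5]],   # chain 1
]

# Combined stats concatenate the chains, so [-1] is the last row of the
# LAST chain only.
combined = [row for chain in chains_scalings for row in chain]
print(combined[-1])                  # [2.2, 2.5]

# Taking [-1] within each chain keeps one set of scalings per chain.
per_chain_last = [chain[-1] for chain in chains_scalings]
print(per_chain_last)                # [[1.1, 1.5], [2.2, 2.5]]
```

A single `step` argument is shared by all chains of one `pm.sample` call. One possible workaround is to continue each chain in its own `pm.sample(chains=1, ...)` call, with a step built from that chain's scalings and its own start point. This is an untested assumption, not a confirmed PyMC3 recipe.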

But I do not know how to initialize the step method/sampler with two sets of scalings.
Please let me know if this is even possible.

It would be great if other developers could help me with this issue too. I am referring to my latest reply, right above.

I also need to know how to do the same thing when NUTS is used. In other words, what needs to be saved in order to continue sampling with NUTS at a later time?
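Gradient-based samplers have different tuned quantities. Once adaptation has finished, an HMC-type sampler is characterized by:

- the current position;
- the step size;
- the mass matrix;
- the RNG state, if you want an exact continuation.

If you stop in the middle of tuning, you would also need the adaptation's running statistics. For NUTS these are the step-size dual-averaging accumulators and the mass-matrix estimator.

In PyMC3 these live on the NUTS step object; I believe the mass matrix is in its `potential` attribute, but check your version's source before relying on attribute names. The sketch below is plain HMC with a fixed number of leapfrog steps, not NUTS and not PyMC3. `logp_and_grad`, `hmc_run` and the target are hypothetical.

```python
import math
import random


def logp_and_grad(q):
    # Hypothetical target: independent normals with standard deviations 1 and 3.
    sd = (1.0, 3.0)
    logp = -0.5 * sum((qi / s) ** 2 for qi, s in zip(q, sd))
    grad = [-qi / s ** 2 for qi, s in zip(q, sd)]
    return logp, grad


def hmc_run(state, n_draws, n_leapfrog=10):
    """Plain HMC with a fixed step size and a diagonal mass matrix."""
    rng = random.Random()
    rng.setstate(state["rng_state"])
    q = list(state["position"])
    eps, inv_m = state["step_size"], state["inv_mass_diag"]
    draws = []
    for _ in range(n_draws):
        # Momentum p ~ N(0, M) with M = diag(1 / inv_m).
        p = [rng.gauss(0.0, 1.0) / math.sqrt(im) for im in inv_m]
        logp0, grad = logp_and_grad(q)
        h0 = -logp0 + 0.5 * sum(pi * pi * im for pi, im in zip(p, inv_m))
        # Leapfrog: half momentum step, then alternating full steps.
        q_new = list(q)
        p_new = [pi + 0.5 * eps * g for pi, g in zip(p, grad)]
        for i in range(n_leapfrog):
            q_new = [qi + eps * im * pi for qi, im, pi in zip(q_new, inv_m, p_new)]
            logp1, grad = logp_and_grad(q_new)
            frac = 1.0 if i < n_leapfrog - 1 else 0.5  # final half step
            p_new = [pi + frac * eps * g for pi, g in zip(p_new, grad)]
        h1 = -logp1 + 0.5 * sum(pi * pi * im for pi, im in zip(p_new, inv_m))
        # Metropolis correction on the total energy.
        if rng.random() < math.exp(min(0.0, h0 - h1)):
            q = q_new
        draws.append(list(q))
    return draws, dict(state, position=q, rng_state=rng.getstate())


state = {
    "position": [2.0, -4.0],
    "step_size": 0.3,              # would come from the tuning phase
    "inv_mass_diag": [1.0, 9.0],   # ditto (here: the target variances)
    "rng_state": random.Random(7).getstate(),
}

full, _ = hmc_run(state, 40)              # one uninterrupted run
part1, checkpoint = hmc_run(state, 20)    # stop halfway, keep the state
part2, _ = hmc_run(checkpoint, 20)        # resume from the checkpoint
print(part1 + part2 == full)              # True
```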

@twiecki and @AlexAndorra your assistance in this matter would be greatly appreciated.