I am currently trying to use a PyMC3 model in an online learning setting. The data arrives one instance at a time, and I need to update the model on the complete data in each iteration.
Since the posterior only changes slightly between iterations and I want to avoid waiting for the tuning steps each time, my plan was to reuse the previous trace as follows:
```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

with model1:
    init_trace = pm.sample(draws=1000, tune=1000)

cov = np.atleast_1d(pm.trace_cov(init_trace))
start = list(np.random.choice(init_trace, chains))  # chains = number of chains to run
potential = quadpotential.QuadPotentialFull(cov)

with pm.Model() as model_new:  # reset model; if you use theano.shared you can also update the data of model1 above
    step = pm.NUTS(potential=potential)
    trace = pm.sample(1000, tune=100, step=step)  # good to still do a bit of tuning
```
That’s an option too. However, since pm.sample() only returns the trace, the user would need to set up the potential and the step method by hand anyway (and also pick the right quadpotential class to replicate the default initialization), so it comes down to about the same amount of code. You also need to reset some properties on the potential to turn tuning back on, etc.
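For instance (a rough sketch, not an official API pattern; seeding the adapting potential from the previous trace's posterior mean and variance, and reusing the initial weight of 10 from the default init, are assumptions), replicating the default adapt_diag setup while keeping adaptation switched on could look like this:

```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

# Flatten the previous trace, then seed a *still-adapting* diagonal
# potential with its posterior mean and variance
samples = np.array([model1.dict_to_array(pt) for pt in init_trace])
potential = quadpotential.QuadPotentialDiagAdapt(
    model1.ndim, samples.mean(axis=0), samples.var(axis=0), 10)

with model1:
    step = pm.NUTS(potential=potential)
    trace = pm.sample(1000, tune=100, step=step)  # adaptation continues from the seed
```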
I did a sanity check to see whether this method can be used to simply resume sampling:
```python
from pymc3.step_methods.hmc import quadpotential

with model:  # initial sample
    step = pm.NUTS()
    trace = pm.sample(draws=100, step=step, tune=3000, cores=n_chains)

with model:
    cov = np.atleast_1d(pm.trace_cov(trace))
    start = list(np.random.choice(trace, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)

with pm.Model() as model2:
    # Reset model here using the same observed data
    step = pm.NUTS(potential=potential)
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains, start=start)
```
This does not work at all: I get acceptance probabilities of 10-40% in the second sampling run. Which other tuning information do I need to carry over?
Thanks for reporting back. You are right, there is more tuning state than just the mass matrix: NUTS (and HMC) also uses dual averaging to adapt the step size. The difficulty here is that:

- you also need to set the step size, otherwise the default is not good once you turn off tuning;
- the step size is not directly settable during init (we should probably change that).

NUTS scales its initial step size down from step_scale by 1/n**0.25, where n is the model dimension, so to reproduce a tuned step size you need to compute the matching step_scale and pass it to init. Please find a minimal working example below:
```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    trace1 = pm.sample(1000, tune=1000, cores=n_chains)

with m:
    cov = np.atleast_1d(pm.trace_cov(trace1))
    start = list(np.random.choice(trace1, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)
    # Recover the tuned step size and convert it into the matching step_scale
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    size = m.bijection.ordering.size
    step_scale = step_size * (size ** 0.25)

with pm.Model() as m2:
    x = pm.Normal('x', shape=10)
    step = pm.NUTS(potential=potential,
                   adapt_step_size=False,
                   step_scale=step_scale)
    step.tune = False
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains, start=start)
```
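As a quick check (a sketch; mean_tree_accept is the NUTS sampler stat holding the average acceptance probability of each draw), you can verify that the resumed run accepts at roughly the target rate:

```python
# Should sit near NUTS's default target_accept of 0.8, not the
# 10-40% seen when the tuned step size is not carried over
print(trace2.get_sampler_stats('mean_tree_accept').mean())
```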
If you are using the same model (i.e., no re-initializing), you can do what @twiecki said. However, it seems the sampler state gets reset somewhere, so you need to restore a few things by hand as well:
```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods import step_sizes
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    # Replicate init == 'jitter+adapt_diag' by hand
    start = []
    for _ in range(n_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(
        m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)
    trace1 = pm.sample(1000, step=step, tune=1000, cores=n_chains)

with m:  # needs to be the same model
    # Restore the tuned step size and switch off further adaptation
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, .75, 10
    )
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)
```
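Putting this back in the context of the original online-learning question, here is a sketch of how the pieces could be combined with a theano.shared observed variable (the toy model, the simulated stream, and the short tune=100 re-tuning are assumptions, not a tested recipe):

```python
import numpy as np
import pymc3 as pm
import theano
from pymc3.step_methods.hmc import quadpotential

data = theano.shared(np.array([0.]))  # observed data, grows over time

with pm.Model() as model:
    mu = pm.Normal('mu', 0., 10.)
    obs = pm.Normal('obs', mu, 1., observed=data)
    trace = pm.sample(1000, tune=1000)  # full tuning once, up front

for new_point in np.random.randn(5):  # stand-in for the incoming stream
    # Append the new instance and re-fit on the complete data
    data.set_value(np.append(data.get_value(), new_point))
    with model:
        # Re-seed the mass matrix and step size from the last trace,
        # then only re-tune briefly
        samples = np.array([model.dict_to_array(pt) for pt in trace])
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, samples.mean(axis=0), samples.var(axis=0), 10)
        step_size = trace.get_sampler_stats('step_size_bar')[-1]
        step = pm.NUTS(potential=potential,
                       step_scale=step_size * model.ndim ** 0.25)
        trace = pm.sample(1000, tune=100, step=step)
```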