Reuse tuning for next sampling call

I am currently trying to use a PyMC3 model in an online learning setting. The data arrives one instance at a time, and I need to update the model on the complete data in each iteration.
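For context, my setup looks roughly like this (a simplified sketch; the model and the stream iterable are placeholders, not my actual code). I keep the growing dataset in a theano.shared variable so the same model graph can be re-sampled each iteration:

import numpy as np
import theano
import pymc3 as pm

data = theano.shared(np.array([0.0]))  # growing dataset

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=10.)
    pm.Normal('obs', mu=mu, sd=1., observed=data)

for new_point in stream:  # 'stream' stands in for the incoming instances
    data.set_value(np.append(data.get_value(), new_point))
    with model:
        trace = pm.sample(draws=25, tune=1000)  # this tuning is what I want to avoid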

Since the posterior distribution only changes slightly and I want to avoid having to wait for the tuning steps to complete, my plan was to use the previous trace as follows:

# Initial sampling & tuning step:
with model:
    step = pm.NUTS()
    trace = pm.sample(draws=100, step=step, tune=1000)


# Online sampling step:
with model:
    trace = pm.sample(draws=25, step=step, trace=trace, tune=0)

But it appears that the sample function does not reuse the initially tuned NUTS step. Is there a way to achieve what I’m looking for?

Note: I am not looking for a solution on how to update the priors.

The best way to do this is to use the trace from a previous sample to initialize the NUTS sampler (we do something similar in https://github.com/pymc-devs/pymc3/blob/f375f27fdcf6568dd662d8db4474497b9bda1f58/pymc3/sampling.py#L1477-L1488):

import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

chains = 4

with model1:
    init_trace = pm.sample(draws=1000, tune=1000)

# Build a full mass matrix from the posterior covariance of the previous
# trace, and draw starting points from that trace:
cov = np.atleast_1d(pm.trace_cov(init_trace))
start = list(np.random.choice(init_trace, chains))
potential = quadpotential.QuadPotentialFull(cov)

with pm.Model() as model_new:  # reset the model; if you use theano.shared you can also update the value of model1 above
    step = pm.NUTS(potential=potential)
    trace = pm.sample(1000, tune=100, step=step, start=start)  # good to still do a bit of tuning

Couldn’t you also just recycle the potential from a previously run NUTS object?

That’s an option too. However, since pm.sample() only returns the trace, the user would need to set up the potential and step method by hand anyway (and also set up the right quadpotential method to replicate the default initialization) - it comes down to around the same amount of code.
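Something along these lines (an untested sketch, not code from this thread: it assumes pm.NUTS keeps the adapted potential on the step object, and that sampling runs in a single process so the step object is actually updated in place):

with model:
    step = pm.NUTS()
    trace = pm.sample(1000, tune=1000, step=step, cores=1)

with model:
    # Reuse the mass matrix adapted during the previous run:
    step2 = pm.NUTS(potential=step.potential)
    trace2 = pm.sample(100, tune=0, step=step2, cores=1)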

Also, you need to reset some properties in the potential to turn the tuning back on, etc.

That tuning also depends on whether your new model has new data (re: “online learning”), in that it may change the scaling of your posterior.

I don’t know how much the tuning depends on the (new) data, but my intuition tells me “maybe a lot”.


Great ideas so far - I will try the approach by junpenglao and see how much I have to retune.

I did a sanity check to see whether this method can be used to simply resume sampling:

import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

with model:  # initial sample
    step = pm.NUTS()
    trace = pm.sample(draws=100, step=step, tune=3000, cores=n_chains)

with model:
    cov = np.atleast_1d(pm.trace_cov(trace))
    start = list(np.random.choice(trace, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)

with pm.Model() as model2:
    # Reset the model here using the same observed data
    step = pm.NUTS(potential=potential)
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains, start=start)

This does not work at all - I get acceptance probabilities of 10-40% for the second sampling run. What other tuning information do I need to carry over?
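(For completeness, I am reading the acceptance probabilities off the trace like this, assuming the standard NUTS sampler statistics:)

accept = trace2.get_sampler_stats('mean_tree_accept')
print(accept.mean(), accept.min(), accept.max())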

Try also specifying the max_treedepth and step_size.

I am not sure whether I am doing it correctly (it still does not work):

with pm.Model() as model2:
    # reset model
    step2 = pm.NUTS(potential=potential, max_treedepth=11)
    # attempt to seed the dual-averaging step size from the previous trace:
    step2.step_adapt._log_bar = np.log(trace['step_size_bar'][-1])
    trace2 = pm.sample(draws=1000, step=step2, tune=0, cores=n_chains, start=start)

A few things I notice here:

  • The step sizes in step.step_adapt.stats() are never updated. Could it be that the step object is copied internally?
  • In a similar vein: The step size I enter manually (above) is not used. I can see this since step.step_adapt._tuned_stats is always [].

edit: This procedure is basically what init_nuts in sampling.py does, so how is it working correctly there?

Thanks for reporting back - you are right, there is more tuning state than just the mass matrix, as NUTS (and HMC) also uses dual averaging for the step size. The difficulty here is that:

  1. you need to also set the step size, otherwise the default is not good once you turn off tuning;
  2. the step size is not directly settable during init (we should probably change that).

To make it work correctly, you need to compute the right step_scale and pass it at init. Please find a minimal working example below:

import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    trace1 = pm.sample(1000, tune=1000, cores=n_chains)

with m:
    cov = np.atleast_1d(pm.trace_cov(trace1))
    start = list(np.random.choice(trace1, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]

    # NUTS internally divides step_scale by size ** 0.25, so multiply it
    # back in to recover the adapted step size:
    size = m.bijection.ordering.size
    step_scale = step_size * (size ** 0.25)

with pm.Model() as m2:
    x = pm.Normal('x', shape=10)
    step = pm.NUTS(potential=potential,
                   adapt_step_size=False,
                   step_scale=step_scale)
    step.tune = False
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains, start=start)

If you are using the same model (i.e., no re-initializing), you can do what @twiecki said; however, it seems the sampler is reset somewhere, which means you need to reset a bunch of state as well:

import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential
from pymc3.step_methods import step_sizes

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    # Replicate the default init == 'jitter+adapt_diag' by hand:
    start = []
    for _ in range(n_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(
        m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)
    trace1 = pm.sample(1000, step=step, tune=1000, cores=n_chains)

with m:  # needs to be the same model
    # Turn tuning off and re-seed the dual-averaging adaptation with the
    # step size from the previous run:
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, .75, 10)
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)
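As a quick sanity check (using the standard 'mean_tree_accept' sampler statistic), both runs should show acceptance rates close to the NUTS target of about 0.8:

print(trace1.get_sampler_stats('mean_tree_accept').mean())
print(trace2.get_sampler_stats('mean_tree_accept').mean())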