Thanks for reporting back. You are right, there is actually more tuning going on: NUTS (and HMC) also uses dual averaging to adapt the step size. The difficulty here is that:
- you also need to set the step size, otherwise the default is not good once you turn off tuning;
- the step size is not directly settable in `__init__` (we should probably change that).
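For context, here is a minimal NumPy sketch of dual-averaging step-size adaptation, following the scheme from the NUTS paper (Hoffman & Gelman, 2014); it is not pymc3's code. While tuning, the step size follows a noisy iterate pushed by the gap between the target and the observed acceptance rate; once tuning stops, a running average (the analogue of `step_size_bar`) is used instead. The class `DualAveragingSketch` and the toy acceptance function `toy_accept_prob` are hypothetical; the defaults `gamma=0.05, k=0.75, t0=10` match the values passed to `DualAverageAdaptation` in the second snippet.

```python
import numpy as np


class DualAveragingSketch:
    """Hypothetical dual-averaging step-size adapter (Hoffman & Gelman, 2014 style)."""

    def __init__(self, initial_step, target_accept=0.8, gamma=0.05, k=0.75, t0=10):
        self.initial_step = initial_step
        self.target = target_accept
        self.gamma, self.k, self.t0 = gamma, k, t0
        self.reset()

    def reset(self):
        # Restart the adaptation from the initial step size.
        self.log_step = np.log(self.initial_step)  # noisy iterate used while tuning
        self.log_bar = self.log_step               # averaged iterate used after tuning
        self.hbar = 0.0                            # running average of (target - accept)
        self.count = 1
        self.mu = np.log(10 * self.initial_step)   # shrinkage point for the log step size

    def current(self, tune):
        # While tuning use the noisy iterate, afterwards the averaged one.
        return float(np.exp(self.log_step if tune else self.log_bar))

    def update(self, accept_stat):
        w = 1.0 / (self.count + self.t0)
        self.hbar = (1 - w) * self.hbar + w * (self.target - accept_stat)
        self.log_step = self.mu - self.hbar * np.sqrt(self.count) / self.gamma
        mk = self.count ** -self.k
        self.log_bar = mk * self.log_step + (1 - mk) * self.log_bar
        self.count += 1


def toy_accept_prob(step):
    # Hypothetical acceptance model: bigger steps are accepted less often.
    return np.exp(-step ** 2)


adapter = DualAveragingSketch(initial_step=1.0, target_accept=0.8)
for _ in range(1000):
    adapter.update(toy_accept_prob(adapter.current(tune=True)))
tuned = adapter.current(tune=False)  # close to sqrt(-log(0.8)) ≈ 0.47 for this toy model

# A fresh adapter started at the tuned value returns to it after a reset.
fresh = DualAveragingSketch(initial_step=tuned)
fresh.reset()
```

In this sketch, `reset()` goes back to the initial step size, so an adapter whose initial step is the tuned value keeps that value even after a reset, which appears to be the idea behind replacing `step.step_adapt` in the second snippet.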
To make it work correctly, you need to compute the right `step_scale` and pass it to `__init__` (i.e., to `pm.NUTS(...)`). Here is a minimal working example:
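```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

# First run: let NUTS tune the mass matrix and the step size as usual.
with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    trace1 = pm.sample(1000, tune=1000, cores=n_chains)

with m:
    # Full mass matrix estimated from the posterior samples of the first run.
    cov = np.atleast_1d(pm.trace_cov(trace1))
    # Random draws from the first trace serve as start points for the new chains.
    start = list(np.random.choice(trace1, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)
    # Final averaged step size found by dual averaging.
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    # NUTS divides step_scale by ndim ** 0.25, so undo that scaling here.
    size = m.bijection.ordering.size
    step_scale = step_size * (size ** 0.25)

# Second run on a freshly created model: reuse potential and step size, no tuning.
with pm.Model() as m2:
    x = pm.Normal('x', shape=10)
    step = pm.NUTS(potential=potential,
                   adapt_step_size=False,
                   step_scale=step_scale)
    step.tune = False
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains, start=start)
```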
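The `size ** 0.25` factor is there because, per the pymc3 `HamiltonianMC`/`NUTS` docstring, `step_scale` is the initial step size before it is scaled down by `1/n**(1/4)`, where `n` is the number of (flattened) parameters. A minimal sketch of the conversion in both directions; the helper names are hypothetical and the tuned value is illustrative, not taken from a real trace:

```python
import numpy as np


def step_scale_for(step_size, ndim):
    # Hypothetical helper: NUTS/HMC divide `step_scale` by ndim ** (1/4),
    # so multiply by that factor to start from a given step size.
    return step_size * ndim ** 0.25


def initial_step_size(step_scale, ndim):
    # Hypothetical helper: the step size the sampler starts from for a given `step_scale`.
    return step_scale / ndim ** 0.25


ndim = 10              # shape=10 model from the example above
tuned_step_size = 0.5  # illustrative value
step_scale = step_scale_for(tuned_step_size, ndim)  # ≈ 0.889
```

Assuming the default `step_scale=0.25`, a 10-dimensional model would start from a step size of about `0.25 / 10 ** 0.25 ≈ 0.14`, which is what you are left with once adaptation is switched off unless you override it.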
If you are reusing the same model (i.e., not re-creating it), you can do what @twiecki said. However, it seems the sampler state is reset somewhere, which means you need to set up a few things yourself as well:
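```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods import step_sizes
from pymc3.step_methods.hmc import quadpotential

n_chains = 4

with pm.Model() as m:
    x = pm.Normal('x', shape=10)
    # Replicate init == 'jitter+adapt_diag' by hand so we keep a handle on `step`:
    # jitter the test point with Uniform(-1, 1) noise, once per chain...
    start = []
    for _ in range(n_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    # ...then start the diagonal mass-matrix adaptation at the mean of those points.
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(
        m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)
    trace1 = pm.sample(1000, step=step, tune=1000, cores=n_chains, start=start)

with m:  # needs to be the same model
    step_size = trace1.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    # The sampler state seems to get reset, so give the step-size adaptation the
    # tuned value as its initial step (args: initial step, target, gamma, k, t0).
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, .75, 10
    )
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)
```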
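The hand-rolled `jitter+adapt_diag` initialization in the second snippet can be reproduced with plain NumPy, which makes it easier to see what goes into `QuadPotentialDiagAdapt`: one start point per chain (test point plus Uniform(-1, 1) jitter), the mean of those points as the initial mean, a vector of ones as the initial diagonal, and a weight of 10. This is a sketch assuming a model with a single variable `x` of shape 10 whose test point is all zeros; `jittered_starts` and `flatten` are hypothetical, with `flatten` standing in for `Model.dict_to_array`.

```python
import numpy as np


def jittered_starts(test_point, n_chains, rng):
    """One start point per chain: a copy of the test point plus Uniform(-1, 1) jitter."""
    starts = []
    for _ in range(n_chains):
        point = {name: np.array(val, dtype=float) for name, val in test_point.items()}
        for val in point.values():
            val[...] += rng.uniform(-1.0, 1.0, size=val.shape)
        starts.append(point)
    return starts


def flatten(point):
    # Hypothetical stand-in for Model.dict_to_array: ravel and concatenate in name order.
    return np.concatenate([np.ravel(point[name]) for name in sorted(point)])


rng = np.random.default_rng(42)
test_point = {"x": np.zeros(10)}  # assumed test point of pm.Normal('x', shape=10)
starts = jittered_starts(test_point, n_chains=4, rng=rng)

initial_mean = np.mean([flatten(p) for p in starts], axis=0)
initial_diag = np.ones_like(initial_mean)
initial_weight = 10  # weight given to the initial guess in the running estimate
```

Copying each value with `np.array(..., dtype=float)` before jittering keeps the shared test point untouched, mirroring the `val.copy()` in the snippet.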