I’m running a sampling simulation in which a fixed model is sampled over multiple rounds, with a given number of draws per round. To save as much time as possible, I would like to tune a NUTS step only once and reuse it in all subsequent rounds.
After following the instructions given in "Reuse tuning for next sampling call", I’m still wondering whether other computations could also be cached or reused. For example, if the sampler builds a computation graph in each sampling round, it would make sense to reuse that graph across rounds rather than rebuild it every time, since it stays the same.
My evidence for this guess is the following: the time consumed when sampling with a tuned NUTS step is not proportional to the number of draws (see the "Sampling using the tuned step" section and the variable n_draws in the code below). For example, sampling only 4 draws took around 1 s, while 4x1000 draws took around 4 s.
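To put rough numbers on this, here is a minimal sketch that reads the two timings above at face value. It assumes (this is my assumption, not something I measured separately) that the wall-clock time of one call is a fixed overhead plus a constant cost per draw. The helper name `affine_fit` is made up for illustration.

```python
# Two-point estimate from the timings quoted above.
# Assumption: time(n) = overhead + per_draw * n, with n the number of draws.
def affine_fit(n_small, t_small, n_large, t_large):
    """Solve for (overhead, per_draw) from two (draws, seconds) measurements."""
    per_draw = (t_large - t_small) / (n_large - n_small)
    overhead = t_small - per_draw * n_small
    return overhead, per_draw

# 4 draws took about 1 s, 4x1000 draws took about 4 s.
overhead, per_draw = affine_fit(4, 1.0, 4 * 1000, 4.0)
print(f"fixed overhead ~ {overhead:.3f} s, per-draw cost ~ {per_draw * 1e3:.3f} ms")
```

Under that assumption, almost all of the 1 s for the small run would be fixed overhead, which is the part I hope can be cached or reused.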

    import numpy as np
    import pymc3 as pm
    from timeit import default_timer
    from scipy.stats import norm, halfnorm
    import matplotlib.pyplot as plt
    from pymc3.step_methods.hmc.nuts import NUTS
    from pymc3.step_methods.hmc import quadpotential

    ### Model
    alpha, sigma = 1, 1
    beta = [1, 2.5]

    # Size of dataset
    size = 100

    # Predictor variables
    X1 = np.random.randn(size)
    X2 = np.random.randn(size) * 0.2

    # Simulate outcome variable
    Y = alpha + beta[0] * X1 + beta[1] * X2 + np.random.randn(size) * sigma

    basic_model = pm.Model()
    with basic_model as m:
        # Priors for unknown model parameters
        alpha = pm.Normal("alpha", mu=0, sigma=10)
        beta = pm.Normal("beta", mu=0, sigma=10, shape=2)
        sigma = pm.HalfNormal("sigma", sigma=1)

        # Expected value of outcome
        mu = alpha + beta[0] * X1 + beta[1] * X2

        # Likelihood (sampling distribution) of observations
        Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

    ### Tuning a NUTS step
    n_chains = 4
    # pm.sample, pm.trace_cov and pm.NUTS all need the model context
    with basic_model:
        init_trace = pm.sample(draws=1000, tune=1000, cores=n_chains)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        # One random start point per chain, drawn from the initial trace
        start = list(np.random.choice(init_trace, n_chains))
        potential = quadpotential.QuadPotentialFull(cov)
        step_size = init_trace.get_sampler_stats("step_size_bar")[-1]
        # Number of free parameters (named n_dim to avoid shadowing the dataset size)
        n_dim = m.bijection.ordering.size
        step_scale = step_size * (n_dim ** 0.25)

        # Setting a tuned NUTS step
        step = pm.NUTS(potential=potential, adapt_step_size=False, step_scale=step_scale)
        step.tune = False

    ### Sampling using the tuned step
    n_draws = 4
    time_zero = default_timer()
    with basic_model:
        trace = pm.sample(draws=n_draws, step=step, tune=0, cores=n_chains, start=start)
    time_consumed = default_timer() - time_zero
    print(time_consumed)
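To check whether the extra time is paid once or in every round, a small timing loop like the one below might help. This is only a sketch: `time_rounds` and `dummy_round` are hypothetical names, and `dummy_round` is a stand-in workload so that the snippet runs without PyMC3.

```python
from timeit import default_timer

def time_rounds(run_round, n_rounds):
    """Call run_round() n_rounds times and return the wall-clock time of each call."""
    timings = []
    for _ in range(n_rounds):
        t0 = default_timer()
        run_round()
        timings.append(default_timer() - t0)
    return timings

# Stand-in workload (hypothetical); replace it with a real sampling call.
def dummy_round():
    sum(i * i for i in range(10_000))

timings = time_rounds(dummy_round, n_rounds=5)
print([f"{t:.4f}" for t in timings])
```

With the model above, `run_round` could be something like `lambda: pm.sample(draws=n_draws, step=step, tune=0, cores=n_chains, start=start, model=basic_model)`. If only the first round is slow, the cost is probably a one-time setup; if every round with a small `n_draws` takes about the same time, the overhead is paid on each call.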