Reusing tuned NUTS steps to save time

I’m running a simulation in which I need to sample from a fixed model over multiple rounds, drawing a given number of samples in each round. To save as much time as possible, I would like to tune a NUTS step only once and reuse it in the later rounds.

After following the instructions in “Reuse tuning for next sampling call”, I’m still wondering whether other computations could be cached and reused. For example, if the sampler builds a computation graph in every sampling round, it would make sense to reuse that graph in later rounds instead of rebuilding it each time, since it does not change.

My evidence for this guess is that the time taken with a tuned NUTS step does not grow in proportion to the number of draws (see the “Sampling using the tuned step” section and the variable n_draws below). For example, sampling only 4 draws took around 1 s, while 4×1000 draws needed only around 4 s, which points to a sizeable fixed cost per sampling call.
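One way to quantify this is to time a few different n_draws values and fit the affine model seconds ≈ fixed + per_draw × n_draws by least squares. The function below is a hypothetical helper, not part of PyMC3; the example call uses the rough figures above, reading “4×1000 draws” as 4000 draws (an assumption).

```python
def fit_fixed_and_per_draw_cost(n_draws, seconds):
    """Least-squares fit of seconds ~ fixed + per_draw * n_draws.

    Returns (fixed, per_draw). Needs at least two distinct n_draws values.
    """
    n = len(n_draws)
    if n != len(seconds) or n < 2:
        raise ValueError("need at least two (n_draws, seconds) pairs")
    mean_x = sum(n_draws) / n
    mean_y = sum(seconds) / n
    sxx = sum((x - mean_x) ** 2 for x in n_draws)
    if sxx == 0:
        raise ValueError("n_draws values must not all be equal")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(n_draws, seconds))
    per_draw = sxy / sxx                 # slope: marginal cost of one draw
    fixed = mean_y - per_draw * mean_x   # intercept: cost paid once per call
    return fixed, per_draw


# Rough figures from the post (assuming "4x1000" means 4000 draws).
fixed, per_draw = fit_fixed_and_per_draw_cost([4, 4000], [1.0, 4.0])
print(f"fixed ~ {fixed:.2f} s, per draw ~ {per_draw * 1000:.2f} ms")
```

With those two points the intercept comes out close to 1 s and the slope below 1 ms per draw, i.e. almost all of the 4-draw run is fixed overhead. More than two measurements give a more trustworthy split.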

```python
import numpy as np
import pymc3 as pm
from timeit import default_timer
from pymc3.step_methods.hmc import quadpotential

### Model

alpha, sigma = 1, 1
beta = [1, 2.5]

# Size of dataset
size = 100

# Predictor variables
X1 = np.random.randn(size)
X2 = np.random.randn(size) * 0.2

# Simulate outcome variable
Y = alpha + beta[0] * X1 + beta[1] * X2 + np.random.randn(size) * sigma

basic_model = pm.Model()

with basic_model as m:

    # Priors for unknown model parameters
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10, shape=2)
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Expected value of outcome
    mu = alpha + beta[0] * X1 + beta[1] * X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

    ### Tuning a NUTS step

    n_chains = 4
    init_trace = pm.sample(draws=1000, tune=1000, cores=n_chains)
    # Full covariance of the tuning draws, used to build the quadratic potential
    cov = np.atleast_1d(pm.trace_cov(init_trace))
    # One random point from the tuning trace per chain, used as start values
    start = list(np.random.choice(init_trace, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)
    # Averaged step size reported by the tuning run
    step_size = init_trace.get_sampler_stats("step_size_bar")[-1]
    # Number of free (unconstrained) parameters; renamed so it does not
    # overwrite the dataset size defined above
    n_params = m.bijection.ordering.size
    step_scale = step_size * (n_params ** 0.25)

    # Setting a tuned NUTS step
    step = pm.NUTS(potential=potential, adapt_step_size=False, step_scale=step_scale)
    step.tune = False

    ### Sampling using the tuned step

    n_draws = 4
    time_zero = default_timer()
    trace = pm.sample(draws=n_draws, step=step, tune=0, cores=n_chains, start=start)
    time_consumed = default_timer() - time_zero
    print(f"{n_draws} draws per chain: {time_consumed:.2f} s")
```
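The two numeric steps in the tuning section can be checked in isolation with NumPy. As I understand PyMC3's HMC base class, it divides `step_scale` by `size ** 0.25` to get the initial step size, so multiplying the adapted step size by the same factor hands NUTS that step size back; treat that as an assumption about PyMC3 internals. For this model the size should be 4 (alpha, the two entries of beta, and the log-transformed sigma), giving a factor of √2. The covariance helper is a plain-NumPy sketch of the idea behind `pm.trace_cov`, assuming draws already stacked as an (n_draws, n_params) array; it is not PyMC3's implementation. All function names here are hypothetical.

```python
import numpy as np


def step_scale_from_step_size(step_size, n_params):
    """Scale an adapted step size up by n_params ** 0.25."""
    return step_size * n_params ** 0.25


def step_size_from_step_scale(step_scale, n_params):
    """The division assumed to happen inside the sampler."""
    return step_scale / n_params ** 0.25


def flattened_cov(draws):
    """Covariance of draws shaped (n_draws, n_params), one column per parameter.

    np.cov squeezes a single-parameter result to a 0-d array;
    np.atleast_1d turns it back into a 1-element array.
    """
    draws = np.asarray(draws, dtype=float)
    return np.atleast_1d(np.cov(draws, rowvar=False))


# Hypothetical adapted step size, not taken from a real run.
scale = step_scale_from_step_size(0.5, 4)
print(scale)                                # 0.5 * sqrt(2), about 0.707
print(step_size_from_step_scale(scale, 4))  # back to (approximately) 0.5

rng = np.random.default_rng(0)
fake_draws = rng.normal(size=(1000, 4))     # synthetic stand-in for a trace
print(flattened_cov(fake_draws).shape)      # (4, 4)
```

Separately, `np.random.choice` samples with replacement by default, so two chains can end up with the same start point; passing `replace=False` avoids that.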