Harnessing multiple cores to speed up fits with a small number of chains

I found this very helpful thread from a few years ago

So I think we have a winner with the following (credit to @junpenglao):

import pymc3 as pm
import numpy as np
from pymc3.step_methods import step_sizes
from pymc3.step_methods.hmc import quadpotential

import warnings
warnings.filterwarnings("ignore");

# Synthetic observations: 1000 draws from a standard normal.
data = np.random.normal(0, 1, 1000)

# A small number of chains is used to adapt the sampler; the adapted step
# object is then reused to draw several chains without tuning again.
n_tune_chains = 1
n_sample_chains = 5

with pm.Model() as m:

    mu = pm.Normal('mu', 0, 1)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=data)

    # Build the potential by hand so that one NUTS step object can be handed
    # to both pm.sample calls below.
    # One jittered copy of the model's test point is created per tuning chain.
    start = []
    for _ in range(n_tune_chains):
        point = {name: val.copy() for name, val in m.test_point.items()}
        for val in point.values():
            # In-place uniform jitter in [-1, 1).
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(point)

    # NOTE(review): `start` is only used to centre the potential; it is not
    # passed to pm.sample, so chains begin wherever pm.sample defaults to.
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    # The trailing 10 is presumably the adapter's initial weight -- confirm
    # against the installed QuadPotentialDiagAdapt signature.
    potential = quadpotential.QuadPotentialDiagAdapt(m.ndim, mean, var, 10)

    step = pm.NUTS(potential=potential)

    # Tuning run: zero kept draws, 5000 tuning iterations. The tuning samples
    # are retained so the sampler statistics can be read back afterwards.
    tune_trace = pm.sample(
        0,
        step=step,
        tune=5000,
        chains=n_tune_chains,
        discard_tuned_samples=False,
    )

    # Take the last recorded 'step_size_bar' value from the tuning run and
    # switch tuning off on the step object.
    step_size = tune_trace.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    # Replace the step-size adapter with one initialised at the tuned value.
    # The constants 0.05, .75, 10 are presumably gamma, k and t0 -- confirm
    # against the installed DualAverageAdaptation signature.
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, .75, 10
    )

    # Sampling run: no tuning, reusing the adapted step object.
    # `chains` is passed explicitly so that `n_sample_chains` really is the
    # number of chains; when it is omitted PyMC3 derives the count from
    # `cores`, which need not equal `n_sample_chains` for small values.
    sample_trace = pm.sample(
        draws=1000,
        step=step,
        tune=0,
        chains=n_sample_chains,
        cores=n_sample_chains,
    )

The above can probably be tidied up a good bit, but it seems to work at least. I can’t seem to upload images here, but the posteriors are spot on.

2 Likes