Harnessing multiple cores to speed up fits with a small number of chains

The code below is not there yet, but it might be a first step toward a solution:

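```python
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt

data = np.random.normal(0, 1, 1000)

with pm.Model() as model:

    mu = pm.Normal('mu', 0, 1)
    sigma = pm.HalfNormal('sigma', 1)

    y = pm.Normal('y', mu=mu, sigma=sigma, observed=data)

    # First run: tuning only (0 draws), keeping the tuning samples
    # so that their sampler statistics can be inspected.
    tune_trace = pm.sample(0, tune=1000, chains=1, discard_tuned_samples=False)

    # Build a fresh NUTS sampler and copy in the final averaged
    # step size from the tuning run.
    new_step = pm.NUTS()
    new_step.step_size = tune_trace.get_sampler_stats("step_size_bar")[-1]
    # Intended to switch off further step-size adaptation.
    new_step.tune = False

    trace = pm.sample(1000, step=new_step)
```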

The code above runs without error; however, it does not apply the step_size manually assigned to new_step:

```
In: new_step.step_size
Out: 1.2404317754404846

In: trace.step_size
Out: [0.2102241, 0.2102241, 0.2102241, ..., 0.2102241, 0.2102241,
       0.2102241]
```
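The constant value in the trace can be checked against the sampler's defaults. Assuming PyMC3's HMC/NUTS samplers set the initial step size to step_scale / ndim**0.25 with a default step_scale of 0.25, and that this model has ndim = 2 free parameters (mu and the log-transformed sigma), the initial step size comes out to exactly the value in the trace. That may suggest the sampler is falling back to its initial step size rather than using the assigned one. A quick check of the arithmetic:

```python
# Assumption: PyMC3's NUTS uses a default step_scale of 0.25 and sets the
# initial step size to step_scale / ndim**0.25.
DEFAULT_STEP_SCALE = 0.25
NDIM = 2  # free parameters in the model: mu and log(sigma)


def initial_step_size(step_scale: float, ndim: int) -> float:
    """Initial NUTS step size under the assumed PyMC3 rule."""
    return step_scale / ndim ** 0.25


initial = initial_step_size(DEFAULT_STEP_SCALE, NDIM)
print(round(initial, 7))  # 0.2102241
```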

However, pm.sample does seem to pick up that new_step is configured not to tune (new_step.tune = False); it just doesn't seem to apply the step_size. Perhaps a few more parameters need to be specified?
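One possible direction, offered as an untested sketch: if PyMC3's NUTS derives its step size from the constructor argument step_scale (assumed here to follow step_size = step_scale / ndim**0.25) rather than from the step_size attribute, the desired step size could be passed in by inverting that rule. It is also assumed that pm.sample sets the sampler's tune flag from its own tune argument, so passing tune=0 explicitly may be needed. The helper step_scale_for below is hypothetical, not part of PyMC3, and only the step size would be carried over, not the tuned mass matrix:

```python
# Hypothetical helper (not part of PyMC3): convert a target NUTS step size
# into the step_scale constructor argument, assuming PyMC3 computes the
# initial step size as step_scale / ndim**0.25.


def step_scale_for(step_size: float, ndim: int) -> float:
    """Return the step_scale that yields `step_size` for an ndim-dimensional model."""
    return step_size * ndim ** 0.25


# Value assigned to new_step.step_size in the question (from step_size_bar).
tuned_step_size = 1.2404317754404846
scale = step_scale_for(tuned_step_size, ndim=2)

# Round trip: feeding `scale` back through the assumed rule recovers the target.
assert abs(scale / 2 ** 0.25 - tuned_step_size) < 1e-12

# Intended use inside the model context (untested; mass matrix NOT reused):
#     tuned = tune_trace.get_sampler_stats("step_size_bar")[-1]
#     new_step = pm.NUTS(step_scale=step_scale_for(tuned, model.ndim))
#     trace = pm.sample(1000, step=new_step, tune=0)
```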