Harnessing multiple cores to speed up fits with a small number of chains

Is there a recommended way to harness the benefits of server nodes with a high number of cores, to speed up model fits in pymc3?

In my specific scenario, I ideally want a small number (~4) of fairly long (~100k samples) chains, so specifying a high number of cores in pm.sample is not appropriate.

One option is to run a high number of short single-chain fits, specifying a large number of cores in pm.sample, and amalgamate these chains later. e.g. pm.sample(1000,…,chains=1,cores=400,…)

However, given that each of these mini chains would require its own tuning, this seems a bit redundant…

Is it appropriate/possible to save a sampler state after tuning and use this with multiple small single-chain fits instead?

e.g., to do a pseudo 4-chain, 100k samples per chain fit:

run 4 tuning traces with no actual samples, e.g.

tune_trace_1 = pm.sample(0,tune=50000,chains=1)
tune_trace_2 = pm.sample(0,tune=50000,chains=1)
tune_trace_3 = pm.sample(0,tune=50000,chains=1)
tune_trace_4 = pm.sample(0,tune=50000,chains=1)

then save/load the states to apply to:

sample_trace_1 = pm.sample(1000,tune=<settings from tune_trace_1>,cores=100)
sample_trace_2 = pm.sample(1000,tune=<settings from tune_trace_2>,cores=100)
sample_trace_3 = pm.sample(1000,tune=<settings from tune_trace_3>,cores=100)
sample_trace_4 = pm.sample(1000,tune=<settings from tune_trace_4>,cores=100)

You would then amalgamate the 100 chains from each trace into 100k-sample arrays, run ESS, R-hat, etc., and ultimately be left with 400k-sample parameter posteriors?
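The pooling and diagnostic step described above might look like the following NumPy-only sketch. It uses synthetic draws in place of real traces, and `gelman_rubin` is a hypothetical helper implementing the classic (non-split) R-hat formula, not PyMC3's own diagnostic.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()      # mean within-chain variance (W)
    between = n_draws * chain_means.var(ddof=1)     # between-chain variance (B)
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
# Stand-in for one parameter: 4 tuned samplers x 100 short chains x 1000 draws.
draws = rng.normal(0.0, 1.0, size=(4, 100, 1000))

# Run the diagnostic on the individual short chains: shape (400, 1000).
short_chains = draws.reshape(-1, draws.shape[-1])
rhat = gelman_rubin(short_chains)

# Pool the 100 short chains from each tuned sampler into one 100k-draw array,
# then flatten everything into a single 400k-draw posterior sample.
pooled = draws.reshape(4, -1)
posterior = pooled.ravel()
```

Computing R-hat on the 400 short chains, rather than on the concatenated 100k arrays, avoids treating the joins between independent chains as if they were consecutive draws.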

No answers here, but I would also be interested in knowing how to do things along these lines. It seems like the sticking point is getting access to the tuned sampler. You can get the initialized sampler from init_nuts and you can pass that to sample() (or iter_sample()). But once the sampler is sampling (e.g., tuning), it seems like the only thing you can easily access is the resulting trace (not the current state of the sampler/step methods).

@cluhmann yeah, as far as I can see, storing the results from appropriately sized tuning (given your model complexity), and then re-using the sampler settings across a large number of small chains in parallel, is the only way to satisfy the trade-off between tuning requirements and redundancy.

Perhaps easier said than done, but it would be great if this could become a pm.sample option. Something like a “tune_sharing” parameter: by default, the sampler tunes once for each specified chain, but this parameter would allow you to share the tuning across more than one chain.

Something like:
pm.sample(1000,tune=50000,tune_sharing=4,cores=400)
This would create 400 chains of 1000 samples, but only 4 samplers tuned on 50000 tuning steps each, which are then shared evenly across the 400 chains, i.e., 100 chains per sampler state.
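To make the proposed even split concrete, here is a minimal pure-Python sketch of the bookkeeping. Note that `tune_sharing` is not an existing pm.sample argument, and `assign_chains` is a hypothetical helper that only illustrates which tuned sampler each chain would reuse.

```python
def assign_chains(n_chains, n_tuned):
    """Map each chain index to the index of the tuned sampler it would reuse."""
    if n_chains % n_tuned != 0:
        raise ValueError("n_chains must be a multiple of n_tuned for an even split")
    per_sampler = n_chains // n_tuned
    return [chain // per_sampler for chain in range(n_chains)]

# 400 sampling chains sharing 4 tuned sampler states.
assignment = assign_chains(n_chains=400, n_tuned=4)
counts = {sampler: assignment.count(sampler) for sampler in range(4)}
```

Here chains 0 to 99 would reuse the first tuned sampler, chains 100 to 199 the second, and so on.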

It seems like the most straightforward way to handle this would be for the PPL to allow you to carry around your sampler/step method, initialize it, tune it, draw samples from it, etc. But that doesn’t really conform to the current pymc API or its design and reliance on the model context. It may still be possible by accessing non-obvious parts of the API.

The code below is not there yet, but perhaps it might be a first step toward a solution:

import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt

data = np.random.normal(0,1,1000)

with pm.Model() as model:
    
    mu = pm.Normal('mu',0,1)
    sigma = pm.HalfNormal('sigma',1)
    
    y = pm.Normal('y',mu=mu,sigma=sigma,observed=data)
    
    tune_trace = pm.sample(0,tune=1000,chains=1,discard_tuned_samples=False)
    new_step=pm.NUTS()
    new_step.step_size=tune_trace.get_sampler_stats("step_size_bar")[-1]
    new_step.tune=False
    
    trace = pm.sample(1000,step=new_step)

The above runs without an error; however, it does not apply the step_size manually set on new_step:

In: new_step.step_size
Out: 1.2404317754404846

In: trace.step_size
Out: [0.2102241, 0.2102241, 0.2102241, ..., 0.2102241, 0.2102241,
       0.2102241]

However, pm.sample is picking up that new_step is configured not to tune (new_step.tune=False); it just doesn’t seem to be applying the step_size. Perhaps a few more parameters need to be specified?

I think you’re on the right track. When you call pm.sample() with new_step, I think it tunes by default. So maybe try

pm.sample(1000,step=new_step, tune=0)

and see if the step size is intact. Then I think you may want to extract a point from tune_trace and feed it in as a start point to new_step. That should (I think?) prevent sample() from trying to initialize. And then I would try to make sure that whatever other sampler parameters are optimized during tuning are also being copied over from tune_trace to new_step. I suspect that the list is longer than step size (tree size? max energy?).

I found this very helpful thread from a few years ago

So I think we have a winner with the following (credit to @junpenglao):

import pymc3 as pm
import numpy as np
from pymc3.step_methods import step_sizes
from pymc3.step_methods.hmc import quadpotential

import warnings
warnings.filterwarnings("ignore")

data = np.random.normal(0,1,1000)

n_tune_chains = 1
n_sample_chains = 5

with pm.Model() as m:
    
    mu = pm.Normal('mu',0,1)
    sigma = pm.HalfNormal('sigma',1)
    y = pm.Normal('y',mu=mu,sigma=sigma,observed=data)
    
    # Need to manually compute the potential
    start = []
    for _ in range(n_tune_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(m.ndim, mean, var, 10)
    
    step = pm.NUTS(potential=potential)
    
    ## Tuning trace (1 chain)
    tune_trace = pm.sample(0, step=step, tune=5000, chains=n_tune_chains, discard_tuned_samples=False)

    ## Take step size from tuning trace
    step_size = tune_trace.get_sampler_stats('step_size_bar')[-1]
    step.tune = False
    step.step_adapt = step_sizes.DualAverageAdaptation(step_size, step.target_accept, 0.05, .75, 10)
    
    ## Sampling trace across 5 chains (no tuning, fewer draws)
    sample_trace = pm.sample(draws=1000, step=step, tune=0, chains=n_sample_chains, cores=n_sample_chains)

The above can probably be tidied a good bit, but seems to work at least. Can’t seem to upload images here but the posteriors are spot on.

Just curious: assuming you pass the same tuning data to each chain, are many short chains equivalent to fewer long chains? Or are there other tangible benefits to running a large number of chains?

I’m wondering if it’s worth the effort to try to parallelize over a Spark cluster.

To add on to the previous question, are you interested in the parameter space position of the tuned chain, the learned covariance matrix, or both?

@JaredStufft my hope is that they would indeed be equivalent, though I suspect autocorrelation in your posterior chain will be the killer here.
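As a rough illustration of why autocorrelation matters here, the NumPy sketch below simulates a synthetic AR(1) series as a stand-in for a correlated chain and estimates its effective sample size. The helpers `ar1_chain` and `lag1_autocorr` are hypothetical, and the formula N(1 - rho)/(1 + rho) is exact only for AR(1) processes, so treat it as an assumption about the chain's structure rather than a general diagnostic.

```python
import numpy as np

def ar1_chain(rho, n_draws, rng):
    """Simulate a stationary AR(1) series with unit marginal variance."""
    x = np.empty(n_draws)
    x[0] = rng.normal()
    noise = rng.normal(scale=np.sqrt(1 - rho**2), size=n_draws)
    for t in range(1, n_draws):
        x[t] = rho * x[t - 1] + noise[t]
    return x

def lag1_autocorr(x):
    """Sample lag-1 autocorrelation of a 1D series."""
    centered = x - x.mean()
    return np.dot(centered[:-1], centered[1:]) / np.dot(centered, centered)

rng = np.random.default_rng(1)
rho = 0.9  # assumed autocorrelation, chosen for illustration only
chain = ar1_chain(rho, 100_000, rng)

rho_hat = lag1_autocorr(chain)
# AR(1) approximation to the effective sample size.
ess = chain.size * (1 - rho_hat) / (1 + rho_hat)
```

With rho = 0.9, only about 5% of the nominal draws count as effective draws, so a 1000-draw chain with similar autocorrelation would contribute only around 50 effective samples.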

@ckrapu I suppose it’s a matter of establishing step size, first and foremost? So the cov matrix would be the more important?