`start` argument with multiple chains?

I have adapted Foreman-Mackey’s “dense mass matrix” approach for my rather high-dimensional, strongly covariant multi-level problem, and I’ve run into an error that I’m having trouble parsing.

I sequentially build up a dense mass matrix in some shorter chains, and then I want to use it to make my actual chains run better.

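import numpy as np
import pymc3
import pymc3 as pm  # both names are used in this post

# get_step_for_trace is the helper defined in the post linked in the docstring
def densemass_sample(model, nstart=200, nburn=200, ntune=5000, **kwargs):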
    '''
    DFM's mass-matrix-learning routine

    https://dfm.io/posts/pymc3-mass-matrix/
    '''
    nwindow = nstart * 2 ** np.arange(np.floor(np.log2((ntune - nburn) / nstart)))
    nwindow = np.append(nwindow, ntune - nburn - np.sum(nwindow))
    nwindow = nwindow.astype('int')

    with model:
        start = None
        burnin_trace = None
        for steps in nwindow:
            step = get_step_for_trace(burnin_trace, regular_window=0)
            burnin_trace = pm.sample(
                start=start, tune=steps, draws=2, step=step,
                compute_convergence_checks=False, discard_tuned_samples=False,
                **kwargs)
            # last point of each burn-in chain: a list of dicts, one per chain
            start = [t[-1] for t in burnin_trace._straces.values()]

        step = get_step_for_trace(burnin_trace, regular_window=0)

        return step, start
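
For reference, here is what the window schedule works out to with the default arguments. This is a pure-Python mirror of the three `nwindow` lines above (the name `window_schedule` is mine, just for illustration), so it can be checked without NumPy or PyMC3:

```python
import math

def window_schedule(nstart=200, nburn=200, ntune=5000):
    """Pure-Python mirror of the nwindow computation in densemass_sample."""
    budget = ntune - nburn  # tuning steps left after the initial burn-in
    # number of windows that double in length and still fit in the budget
    n_doubling = math.floor(math.log2(budget / nstart))
    windows = [nstart * 2 ** k for k in range(n_doubling)]
    # the final window takes whatever is left of the budget
    windows.append(budget - sum(windows))
    return [int(w) for w in windows]

print(window_schedule())  # [200, 400, 800, 1600, 1800]
```

So with the defaults there are five successive `pm.sample` calls, each one starting from the last points of the previous run.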

Then I try to run several more chains using the better mass matrix, but get a weird error:

with model:
    nchains = 8
    test_trace = pymc3.sample(
        step=step, start=start,
        draws=10, tune=50, burn=50, cores=1, chains=nchains)

Sequential sampling (8 chains in 1 job)
NUTS: [effQH, logQH, tauV 1mmu, tauV mu, age, logU, logZ, logZ-rad-sigma, logZ-r_rotated_, eta, ls-logZ]
100%|████████████████████████████████████████████████| 60/60 [02:08<00:00,  2.15s/it]

IndexError                                Traceback (most recent call last)
<ipython-input-23-6b8d8471f9bd> in <module>()
      3     test_trace = pymc3.sample(step=step, start=start,
      4                               draws=10, tune=50, burn=50,
----> 5                               cores=1, chains=nchains)
      6 

/usr/data/minhas/zpace/miniconda3/envs/MaNGA-metallicity/lib/python3.5/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    467                 _log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
    468                 _print_step_hierarchy(step)
--> 469                 trace = _sample_many(**sample_args)
    470 
    471         discard = tune if discard_tuned_samples else 0

/usr/data/minhas/zpace/miniconda3/envs/MaNGA-metallicity/lib/python3.5/site-packages/pymc3/sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
    512     traces = []
    513     for i in range(chains):
--> 514         trace = _sample(draws=draws, chain=chain + i, start=start[i],
    515                         step=step, random_seed=random_seed[i], **kwargs)
    516         if trace is None:

IndexError: list index out of range

So `sample` finishes one chain fine, but then fails to start the next. Is this a problem with specifying a single `start` for multiple chains run sequentially?

With this in mind, I tried passing a list `start = [start, ] * nchains`, but got the following error: `TypeError: start argument must be a dict or an array-like of dicts`. Am I missing some other subtlety here?

(EDIT: fat-fingered submission, full question now present)

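I found a solution: `start` as returned by my function is already a list of dicts (one per burn-in chain), so wrapping it in another list produces a list of lists, which `sample` rejects. Replicating the list itself with `start * nchains` gives a flat list with at least `nchains` start points, and passing that to `pymc3.sample` works.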
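
To spell out why one form fails and the other works, here is a small sketch with plain dicts standing in for the start points (no PyMC3 needed). The values are made up for illustration, and `tile_start` is a hypothetical helper of mine, not part of PyMC3; it just trims the replicated list to exactly `nchains` entries.

```python
def tile_start(start, nchains):
    """Repeat a list of start-point dicts until there are exactly nchains of them."""
    if isinstance(start, dict):
        start = [start]  # accept a single point as well
    if not start:
        raise ValueError("start must contain at least one point")
    reps = -(-nchains // len(start))  # ceiling division
    return (list(start) * reps)[:nchains]

# stand-in for the list returned by densemass_sample (values are invented)
start = [{"logZ": 0.1, "age": 9.5}]
nchains = 8

nested = [start] * nchains  # list of lists: not an "array-like of dicts"
flat = start * nchains      # flat list of dicts, one per chain

print(type(nested[0]).__name__, type(flat[0]).__name__)  # list dict
print(len(tile_start(start, nchains)))                   # 8
```

Note that every entry in the replicated list refers to the same dict object. I am assuming that is harmless here because nothing modifies the start points in place.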