I have emulated Foreman-Mackey’s “dense mass matrices” approach for my rather high-dimensional and covariant multi-level problem, and I’ve run into an error that I’m having trouble parsing.
I build up a dense mass matrix over a sequence of progressively longer tuning runs, and then I want to reuse it so that my actual chains sample more efficiently.
    import numpy as np
    import pymc3 as pm

    # get_step_for_trace is the helper defined in the linked post (not shown here).

    def densemass_sample(model, nstart=200, nburn=200, ntune=5000, **kwargs):
        '''
        DFM's mass-matrix-learning routine
        https://dfm.io/posts/pymc3-mass-matrix/
        '''
        # Window lengths: nstart, 2*nstart, 4*nstart, ..., plus a final remainder window
        nwindow = nstart * 2 ** np.arange(np.floor(np.log2((ntune - nburn) / nstart)))
        nwindow = np.append(nwindow, ntune - nburn - np.sum(nwindow))
        nwindow = nwindow.astype('int')

        with model:
            start = None
            burnin_trace = None
            for steps in nwindow:
                # Re-estimate the mass matrix from the previous window's trace
                step = get_step_for_trace(burnin_trace, regular_window=0)
                burnin_trace = pm.sample(
                    start=start, tune=steps, draws=2, step=step,
                    compute_convergence_checks=False, discard_tuned_samples=False,
                    **kwargs)
                # One starting point (the last sample) per burn-in chain
                start = [t[-1] for t in burnin_trace._straces.values()]

            step = get_step_for_trace(burnin_trace, regular_window=0)

        return step, start
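For reference, here is a minimal standard-library sketch of the window schedule that the first three lines of the function compute. It assumes the default arguments above; `window_schedule` is a hypothetical name introduced only for this illustration and is not part of the linked post or of PyMC3.

```python
import math


def window_schedule(nstart=200, nburn=200, ntune=5000):
    """Tuning-window lengths: doubling windows plus one final remainder window.

    Hypothetical helper mirroring the NumPy expressions in densemass_sample.
    """
    # Number of doubling windows, matching np.floor(np.log2(...)) above
    n_doubling = int(math.floor(math.log2((ntune - nburn) / nstart)))
    # nstart, 2*nstart, 4*nstart, ...
    windows = [nstart * 2 ** k for k in range(n_doubling)]
    # Whatever remains of the (ntune - nburn) budget goes in a last window
    windows.append(ntune - nburn - sum(windows))
    return windows


schedule = window_schedule()
print(schedule)
```

With these defaults the loop therefore calls `pm.sample` once per window, and the window lengths add up to `ntune - nburn` tuning steps in total.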
Then I try to run several more chains using the better mass matrix, but get a weird error:
    with model:
        nchains = 8
        test_trace = pymc3.sample(
            step=step, start=start,
            draws=1000, tune=500, burn=500, cores=1, chains=nchains)
    Sequential sampling (8 chains in 1 job)
    NUTS: [effQH, logQH, tauV 1mmu, tauV mu, age, logU, logZ, logZ-rad-sigma, logZ-r_rotated_, eta, ls-logZ]
    100%|████████████████████████████████████████████████| 60/60 [02:08<00:00, 2.15s/it]

    IndexError                                Traceback (most recent call last)
    <ipython-input-23-6b8d8471f9bd> in <module>()
          3     test_trace = pymc3.sample(step=step, start=start,
          4                               draws=10, tune=50, burn=50,
    ----> 5                               cores=1, chains=nchains)
          6

    /usr/data/minhas/zpace/miniconda3/envs/MaNGA-metallicity/lib/python3.5/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
        467         _log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
        468         _print_step_hierarchy(step)
    --> 469         trace = _sample_many(**sample_args)
        470
        471     discard = tune if discard_tuned_samples else 0

    /usr/data/minhas/zpace/miniconda3/envs/MaNGA-metallicity/lib/python3.5/site-packages/pymc3/sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
        512     traces = []
        513     for i in range(chains):
    --> 514         trace = _sample(draws=draws, chain=chain + i, start=start[i],
        515                         step=step, random_seed=random_seed[i], **kwargs)
        516         if trace is None:

    IndexError: list index out of range
So `sample` finishes one chain fine, but then fails to start the next. Is this a problem with specifying a single `start` for multiple chains run sequentially?
With this in mind, I tried passing a list, `start = [start, ] * nchains`, but got the following error: `TypeError: start argument must be a dict or an array-like of dicts`. Am I missing some other subtlety here?
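One possible reading of the two errors can be sketched without PyMC3 at all. The sketch assumes the burn-in ran a single chain, so that the `start` returned by `densemass_sample` is a list holding exactly one point dict; the variable names and values in the dict are made up for illustration.

```python
# Hypothetical stand-in for the returned `start`: a list with ONE point dict,
# as would happen if the burn-in used a single chain (an assumption).
start = [{"logZ": 0.1, "age": 9.5}]
nchains = 8

# The traceback shows _sample_many indexing start[i] for each chain, so any
# i >= len(start) raises IndexError once the first chain has finished.
try:
    start[1]
    index_error_raised = False
except IndexError:
    index_error_raised = True

# [start] * nchains repeats the LIST, giving a list of lists rather than a
# list of dicts, which fits the "array-like of dicts" complaint.
nested = [start] * nchains

# Repeating the point dict itself gives one dict per chain.
per_chain = [dict(start[0]) for _ in range(nchains)]

print(index_error_raised, type(nested[0]).__name__, type(per_chain[0]).__name__)
```

If that reading is right, a list like `per_chain` has the shape the error message asks for, though I have not checked it against this PyMC3 version. Passing `chains=nchains` through `densemass_sample`'s `**kwargs`, so that the burn-in yields one start point per final chain, may be another route.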
(EDIT: fat-fingered submission, full question now present)