Recently I noticed that I can’t sample certain models with several parallel chains, while others work fine. Here is a minimal example that fails:
import numpy as np
import pymc3 as pm

n = 96
X = np.random.rand(n, 1)
observed = np.random.rand(n)

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 1)
    gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    gp_f = gp.prior('gp_f', X=X)
    val_obs = pm.Normal('val_obs', mu=gp_f, sd=0.1, observed=observed)
    trace = pm.sample(njobs=2)
Whether I run it as a standalone Python script or from a notebook, the output is the following:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [gp_f_rotated_]
0%| | 0/1000 [00:00<?, ?it/s]
It then stays in this state indefinitely, without putting any load on the CPU. Here is the traceback when I interrupt the process after waiting a while:
Traceback (most recent call last):
  File "tmp.py", line 14, in <module>
    trace = pm.sample(njobs=2)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 437, in sample
    trace = _mp_sample(**sample_args)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 967, in _mp_sample
    traces = Parallel(n_jobs=cores)(jobs)
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 789, in __call__
    self.retrieve()
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 699, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 638, in get
    self.wait(timeout)
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 635, in wait
    self._event.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 551, in wait
    signaled = self._cond.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 295, in wait
    waiter.acquire()
KeyboardInterrupt
^C
This really is a minimal example: if I reduce n from 96 to 95, sampling succeeds! Various other models also work without problems, and the example above works correctly if I set njobs=1.
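A threshold like this could be located automatically with a timeout-based check along the following lines. This is only a sketch: `finishes_in_time`, `first_hanging_n` and `STAND_IN` are hypothetical names, and the stand-in script merely sleeps to imitate the hang instead of running the model above. It assumes the real script would read n from its first command-line argument.

```python
import subprocess
import sys

# Hypothetical stand-in for the real script: it reads n from the command line
# and sleeps for a long time when n >= 96 to imitate the hang. To test the
# actual model, replace this source with the script above, parametrised by n.
STAND_IN = """
import sys, time
n = int(sys.argv[1])
if n >= 96:
    time.sleep(3600)
"""


def finishes_in_time(source, n, timeout):
    """Run `source` in a fresh interpreter with n as argv[1].

    Returns False if the run does not finish within `timeout` seconds.
    """
    try:
        # subprocess.run kills the child process when the timeout expires.
        subprocess.run(
            [sys.executable, "-c", source, str(n)],
            timeout=timeout,
            check=True,
        )
    except subprocess.TimeoutExpired:
        return False
    return True


def first_hanging_n(source, candidates, timeout):
    """Return the first n in `candidates` whose run times out, or None."""
    for n in candidates:
        if not finishes_in_time(source, n, timeout):
            return n
    return None
```

For example, `first_hanging_n(STAND_IN, range(90, 100), timeout=5)` walks through the candidate sizes and stops at the first one that does not finish. With the real model the timeout would need to be well above the normal sampling time, and worker processes started by the script may have to be cleaned up separately.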
I can reproduce it on two Linux PCs, where everything Python-related is installed the same way, using Anaconda in a Docker container.
Any ideas how to debug the issue?