Recently I noticed that I can’t sample certain models with several parallel chains, while others work fine. Here is a minimal example of what doesn’t work:
    import numpy as np
    import pymc3 as pm

    n = 96
    X = np.random.rand(n, 1)
    observed = np.random.rand(n)

    with pm.Model() as model:
        cov = pm.gp.cov.ExpQuad(1, 1)
        gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
        gp_f = gp.prior('gp_f', X=X)
        val_obs = pm.Normal('val_obs', mu=gp_f, sd=0.1, observed=observed)
        trace = pm.sample(njobs=2)
Whether I run it as a standalone Python script or from a notebook, the output is the following:
    Auto-assigning NUTS sampler...
    Initializing NUTS using jitter+adapt_diag...
    Multiprocess sampling (2 chains in 2 jobs)
    NUTS: [gp_f_rotated_]
      0%|          | 0/1000 [00:00<?, ?it/s]
It then stays in this state indefinitely, without loading the CPU at all. Interrupting the process after some waiting gives this traceback:
    Traceback (most recent call last):
      File "tmp.py", line 14, in <module>
        trace = pm.sample(njobs=2)
      File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 437, in sample
        trace = _mp_sample(**sample_args)
      File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 967, in _mp_sample
        traces = Parallel(n_jobs=cores)(jobs)
      File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 789, in __call__
        self.retrieve()
      File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 699, in retrieve
        self._output.extend(job.get(timeout=self.timeout))
      File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 638, in get
        self.wait(timeout)
      File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 635, in wait
        self._event.wait(timeout)
      File "*/anaconda3/lib/python3.6/threading.py", line 551, in wait
        signaled = self._cond.wait(timeout)
      File "*/anaconda3/lib/python3.6/threading.py", line 295, in wait
        waiter.acquire()
    KeyboardInterrupt
    ^C
This really is a minimal example: if I reduce n from 96 to 95, sampling succeeds! Various other models also work without problems, and the example above works correctly if I set njobs = 1.
I can reproduce it on two Linux PCs, where everything Python-related is installed the same way, using Anaconda in a Docker container.
Any ideas how to debug the issue?
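One thing I could try, since the traceback above only shows the parent process waiting on its workers: the standard-library faulthandler module can register a signal handler that dumps the Python stack of every thread. Below is a minimal sketch, not something I have verified against this particular hang. It assumes the worker processes are forked on Linux and therefore inherit the handler; enable_stack_dump is a hypothetical helper name, not part of PyMC3 or joblib.

```python
import faulthandler
import os
import signal
import sys
import tempfile
import time


def enable_stack_dump(file=sys.stderr, signum=signal.SIGUSR1):
    """On `signum`, write the Python stack of every thread to `file`.

    Hypothetical helper, not part of PyMC3. The idea is to call it before
    pm.sample() so that forked worker processes inherit the handler.
    """
    faulthandler.register(signum, file=file, all_threads=True)


# Self-check: send the signal to the current process and read back the dump.
log_file = tempfile.TemporaryFile(mode="w+")
enable_stack_dump(file=log_file)
os.kill(os.getpid(), signal.SIGUSR1)
time.sleep(0.1)  # give the handler a moment to run
log_file.seek(0)
dump = log_file.read()
print(dump)
```

With enable_stack_dump() called at the top of tmp.py, running kill -USR1 <worker pid> from another shell while sampling hangs should print where that worker is stuck, which may show whether it is blocked inside the model code or somewhere in the multiprocessing machinery.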