Sampling doesn't start when njobs > 1 for some models


#1

Recently I noticed that I can’t sample certain models with several parallel chains, while others work fine. Here is a minimal example that doesn’t work:

```python
import numpy as np
import pymc3 as pm

n = 96
X = np.random.rand(n, 1)
observed = np.random.rand(n)

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 1)
    gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    gp_f = gp.prior('gp_f', X=X)
    val_obs = pm.Normal('val_obs', mu=gp_f, sd=0.1, observed=observed)

    trace = pm.sample(njobs=2)
```

Whether I run it as a standalone Python script or from a notebook, the output is the following:

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [gp_f_rotated_]
  0%|                                                                              | 0/1000 [00:00<?, ?it/s]
```

It then stays in this state indefinitely, without loading the CPU at all. Here is the traceback when I interrupt the process after waiting a while:

```
Traceback (most recent call last):
  File "tmp.py", line 14, in <module>
    trace = pm.sample(njobs=2)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 437, in sample
    trace = _mp_sample(**sample_args)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 967, in _mp_sample
    traces = Parallel(n_jobs=cores)(jobs)
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 789, in __call__
    self.retrieve()
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 699, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 638, in get
    self.wait(timeout)
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 635, in wait
    self._event.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 551, in wait
    signaled = self._cond.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 295, in wait
    waiter.acquire()
KeyboardInterrupt
^C
```

This really is a minimal example: if I reduce n from 96 to 95, sampling succeeds! Various other models also work without problems. And the example above works correctly if I set njobs = 1.

I can reproduce it on two Linux PCs, where everything Python-related is installed in the same way, using Anaconda in a Docker container.
Any ideas how to debug the issue?
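One standard-library option for seeing where a hung Python process is stuck, without waiting for Ctrl-C, is `faulthandler`. The sketch below (assuming Python 3.3+) arms `faulthandler.dump_traceback_later` before a long call, so that if the call is still running after a deadline, the stacks of all threads in that process are written out. `stalled_step` and `run_with_watchdog` are hypothetical names; in the script above, the wrapped call would be `pm.sample(njobs=2)`. It only dumps the process where it is armed, i.e. the parent. For the workers, `faulthandler.register(signal.SIGUSR1)` (Unix only) lets you trigger a dump with `kill -USR1 <pid>`, and since forked children inherit signal handlers it may also work for forked workers, though that has not been checked against PyMC3.

```python
import faulthandler
import sys
import tempfile
import time


def stalled_step(seconds):
    """Hypothetical stand-in for the call that hangs, e.g. pm.sample(njobs=2)."""
    time.sleep(seconds)


def run_with_watchdog(func, arg, dump_after, out=sys.stderr):
    """Call func(arg). If it is still running after `dump_after` seconds,
    faulthandler writes the Python stack of every thread in this process to `out`
    (which must be a real file with a file descriptor)."""
    faulthandler.dump_traceback_later(dump_after, repeat=False, file=out)
    try:
        return func(arg)
    finally:
        # Disarm the watchdog once the call returns (or raises).
        faulthandler.cancel_dump_traceback_later()


if __name__ == "__main__":
    # Write the dump to a temporary file so it can be read back afterwards.
    with tempfile.TemporaryFile() as log:
        run_with_watchdog(stalled_step, 1.0, dump_after=0.2, out=log)
        log.seek(0)
        print(log.read().decode())
```

In a notebook, `sys.stderr` usually has no usable file descriptor, so pass an open file explicitly as `out`.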


#2

This is likely a memory problem. I see it all the time on my Mac… Setting njobs=1 helps because either the model runs, or, if it fails, it won’t fail silently.


#3

By memory problem you mean running out of memory, right? That’s surprising, because the machines have 64 GB of RAM, and in any case much larger models, including GPs, can be sampled (using njobs = 1).
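A rough back-of-the-envelope check is consistent with this. Assuming the largest per-chain objects are the dense n×n covariance matrix of the latent GP and its Cholesky factor, the sketch below computes the size of one such matrix for n = 95 and n = 96, in both float32 (the dtype joblib reports later in this thread) and float64. `dense_matrix_bytes` is a hypothetical helper, and it says nothing about what Theano or joblib allocate internally.

```python
import struct

# Bytes per element for a C float (float32) and a C double (float64).
ITEMSIZE = {"float32": struct.calcsize("f"), "float64": struct.calcsize("d")}


def dense_matrix_bytes(n, dtype="float64"):
    """Memory for one dense n x n matrix, e.g. a GP covariance or its Cholesky factor."""
    return n * n * ITEMSIZE[dtype]


for n in (95, 96):
    for dtype in ("float32", "float64"):
        print(n, dtype, dense_matrix_bytes(n, dtype), "bytes")
```

This gives 72,200 bytes for n = 95 and 73,728 bytes for n = 96 in float64, a difference of about 1.5 KB per matrix, many orders of magnitude below 64 GB. That makes a plain out-of-memory condition an unlikely explanation for the 95/96 threshold on its own.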


#4

Hmm, in that case maybe it is a joblib problem. I remember seeing one in the issues, though I’m not sure it is related: https://github.com/pymc-devs/pymc3/issues/2640


#5

It’s very difficult to guess what the problem is, as it doesn’t raise any exception; it just stalls indefinitely. Is there something like a “verbose” mode in PyMC3, or similar?
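The traceback in the first post ends in `pool.get(timeout)` → `self._event.wait(timeout)`: the parent process is simply waiting for worker results that never arrive. joblib passes `self.timeout` through to that `get()` (joblib’s `Parallel` accepts a `timeout` argument), and when a timeout is set, the wait raises instead of blocking forever. The sketch below shows the same pattern with the standard-library `multiprocessing.Pool`, where `AsyncResult.get(timeout=...)` raises `multiprocessing.TimeoutError`. `run_chain` and `collect_with_timeout` are hypothetical stand-ins, not PyMC3 functions, and the `fork` start method is assumed because the report is from Linux.

```python
import multiprocessing
import time


def run_chain(seconds):
    """Hypothetical stand-in for one sampling chain: sleep, then return."""
    time.sleep(seconds)
    return seconds


def collect_with_timeout(durations, timeout):
    """Run one job per worker and return the results in order.

    Each .get(timeout=...) raises multiprocessing.TimeoutError if its job has
    not finished in time, instead of blocking forever like a bare .get().
    """
    # Assumption: Linux, where the "fork" start method is available.
    ctx = multiprocessing.get_context("fork")
    # Leaving the with-block calls pool.terminate(), killing any stuck worker.
    with ctx.Pool(processes=len(durations)) as pool:
        pending = [pool.apply_async(run_chain, (d,)) for d in durations]
        return [job.get(timeout=timeout) for job in pending]


if __name__ == "__main__":
    print(collect_with_timeout([0.1, 0.2], timeout=5))  # [0.1, 0.2]
    try:
        collect_with_timeout([0.1, 3.0], timeout=0.5)
    except multiprocessing.TimeoutError:
        print("a worker stalled: raised TimeoutError instead of hanging")
```

A timeout does not fix the underlying stall, but it turns an indefinite hang into an exception that can be caught, logged, or used to fall back to a single process.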


#6

There is a Theano verbose mode, but in this case the problem is likely joblib-related. Maybe you can try turning on joblib’s verbose mode: https://pythonhosted.org/joblib/parallel.html


#7

Well, joblib’s verbose mode doesn’t seem to help at all: it prints many messages of the form `Pickling array (shape=(*,), dtype=float32).`, but still no errors, and sampling stalls as before.