Sampling doesn't start when njobs > 1 for some models


#1

Recently I noticed that I can’t sample certain models with several parallel chains, while others work fine. Here is a minimal example that fails:

import numpy as np
import pymc3 as pm

n = 96
X = np.random.rand(n, 1)
observed = np.random.rand(n)

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 1)
    gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    gp_f = gp.prior('gp_f', X=X)
    val_obs = pm.Normal('val_obs', mu=gp_f, sd=0.1, observed=observed)

    trace = pm.sample(njobs=2)

Whether it is run as a standalone Python script or from a notebook, the output is the following:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [gp_f_rotated_]
  0%|                                                                              | 0/1000 [00:00<?, ?it/s]

It then stays in this state indefinitely, with no CPU load at all. This is the traceback when I interrupt the process after waiting for a while:

Traceback (most recent call last):
  File "tmp.py", line 14, in <module>
    trace = pm.sample(njobs=2)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 437, in sample
    trace = _mp_sample(**sample_args)
  File "*/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py", line 967, in _mp_sample
    traces = Parallel(n_jobs=cores)(jobs)
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 789, in __call__
    self.retrieve()
  File "*/anaconda3/lib/python3.6/site-packages/joblib/parallel.py", line 699, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 638, in get
    self.wait(timeout)
  File "*/anaconda3/lib/python3.6/multiprocessing/pool.py", line 635, in wait
    self._event.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 551, in wait
    signaled = self._cond.wait(timeout)
  File "*/anaconda3/lib/python3.6/threading.py", line 295, in wait
    waiter.acquire()
KeyboardInterrupt
^C
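
The last frames show the parent process blocked in multiprocessing’s get(timeout), waiting for a worker result that never arrives. With no timeout set, that wait has no upper bound. As a rough standard-library sketch of that mechanism (not PyMC3 or joblib code), the example below uses a thread pool as a stand-in for the worker processes, and stalled_worker is a hypothetical placeholder for a chain that does not report back in time. With a finite timeout, the wait ends in an exception instead of a silent hang.

```python
import multiprocessing
import time
from multiprocessing.pool import ThreadPool


def stalled_worker(seconds):
    # Hypothetical stand-in for a chain that does not report back in time.
    time.sleep(seconds)
    return "finished"


# A thread pool is used here only to keep the sketch self-contained;
# it exposes the same apply_async()/get(timeout) interface as a process pool.
with ThreadPool(processes=1) as pool:
    job = pool.apply_async(stalled_worker, (1.0,))
    try:
        # With timeout=None this call would block until the worker returns.
        outcome = job.get(timeout=0.2)
    except multiprocessing.TimeoutError:
        outcome = "timed out"

print(outcome)
```

joblib’s Parallel also accepts a timeout argument, which is what self.timeout in the traceback refers to. Whether pm.sample lets you set it is not something this thread establishes.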

This really is a minimal example: if I reduce n from 96 to 95, sampling succeeds! Various other models also work without problems, and the example above works correctly if I set njobs = 1.

I can reproduce it on two Linux PCs, where everything Python-related is installed the same way, using Anaconda in a Docker container.
Any ideas how to debug the issue?


#2

This is likely a memory problem. I see it all the time on my Mac… Setting njobs=1 helps because either the model runs, or, if it fails, it won’t fail silently.


#3

By memory problem you mean low memory, right? That’s surprising because the machines have 64 GB of RAM, and in any case much larger models, including GPs, can be sampled (using njobs = 1).


#4

Hmm, in that case maybe it is a joblib problem. I remember seeing one in the issues ( https://github.com/pymc-devs/pymc3/issues/2640 ), though I’m not sure it is related.


#5

It’s very difficult to guess what the problem is, as it doesn’t raise any exception and just stalls indefinitely. Maybe there is something like a “verbose” mode in PyMC3, or similar?
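
This is not a PyMC3 feature, but for a hang with no exception, one generic option is Python’s standard-library faulthandler module, which can dump the stack of every thread once a timeout expires. A minimal sketch, assuming a hypothetical slow_step function standing in for the call that stalls (for example pm.sample), with the dump written to a temporary file so it can be inspected afterwards:

```python
import faulthandler
import tempfile
import time


def slow_step(seconds):
    # Hypothetical stand-in for the call that stalls (e.g. pm.sample).
    time.sleep(seconds)
    return "done"


# faulthandler writes straight to the file descriptor, so a real file is needed.
log = tempfile.TemporaryFile(mode="w+")

# If the process is still busy after 1 second, write the stack of every
# thread to `log`; exit=False lets the program keep running afterwards.
faulthandler.dump_traceback_later(1, exit=False, file=log)
result = slow_step(1.5)
# Cancel the pending dump once the call has returned.
faulthandler.cancel_dump_traceback_later()

log.seek(0)
dump = log.read()
log.close()
print(dump)
```

This only covers the process that arms it. Worker processes would each need the same call to show where they are stuck.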


#6

There is a Theano verbose mode, but in this case the problem is likely joblib-related. Maybe you can try turning on verbose output in joblib: https://pythonhosted.org/joblib/parallel.html


#7

Well, the joblib verbose mode doesn’t seem to help at all: it prints many messages of the form “Pickling array (shape=(*,), dtype=float32).”, and there are still no errors while sampling stalls.


#8

I am curious whether there is any update on this post. I have a similar problem.


#9

Yeah, I still experience this issue too, and not only for the model mentioned in the first post. Entirely different classes of models don’t sample in multiple processes, so it’s not specific to Gaussian processes as I thought at first.


#10

I suspect it has something to do with joblib pickling. Unfortunately, I cannot really pin down where it goes wrong yet.
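
If pickling is the suspect, one cheap first check is to pickle by hand the pieces that get sent to the workers. The sketch below uses only the standard-library pickle module. The names pickle_report and candidates are hypothetical, chosen for illustration, and joblib may serialize objects differently, so a clean result here does not rule joblib out. It also only catches objects that fail to pickle, not a stall.

```python
import pickle


def pickle_report(name, obj):
    """Try a pickle round trip and return (name, ok, detail)."""
    try:
        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        pickle.loads(payload)
    except Exception as exc:
        return (name, False, f"{type(exc).__name__}: {exc}")
    return (name, True, f"{len(payload)} bytes")


# Hypothetical candidates; in practice these would be the model pieces
# or arguments handed to the worker processes.
candidates = {
    "plain_dict": {"n": 96, "sd": 0.1},
    "lambda": lambda x: x + 1,  # the stdlib pickler cannot serialize lambdas
}

reports = [pickle_report(name, obj) for name, obj in candidates.items()]
for name, ok, detail in reports:
    print(name, "OK" if ok else "FAILED", detail)
```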