pm.sample gets stuck after init with cores > 1

Hi, novice PyMC user here. I’ve built a linear regression model with about 90 predictors (features) and 3950 samples. The obs_data contains (what I believe are) sensible, standardised values (no NaNs or Infs). Here’s my (pseudo)code:

    obs_data = pd.DataFrame(...)  # shape (3950, 91): 90 predictors + the target
    predictors = [c for c in obs_data.columns if c != 'target']
    glm_formula = 'target ~ ' + ' + '.join(predictors)

    with pm.Model() as model:
        priors = {}
        for predictor in predictors:
            priors[predictor] = pm.Normal.dist(0., 25.)
        priors['Intercept'] = pm.Normal.dist(0., 25.)

        pm.GLM.from_formula(
            glm_formula,
            obs_data,
            priors=priors)

        prior_predictive = pm.sample_prior_predictive()
        trace = pm.sample(draws=2000, chains=2, cores=1, tune=2000)
        posterior_predictive = pm.sample_posterior_predictive(trace)

If I set cores > 1 in pm.sample, the sampling doesn’t start. With a single core, sampling works (albeit slowly). I’m running on an 8-core, 32 GB RAM (virtual) machine, so memory and/or CPU cores shouldn’t be an issue. I read in an older post that this may be caused by a bug in joblib? Any ideas?

OS: Linux x86_64
pymc version: 3.8
joblib version: 0.14.1
environment: Jupyter notebook

Later EDIT: same problem if I run the code as a command line script.

Thank you!

Hi,
What’s the error you’re getting back when cores > 1 ?

Hi there. There’s no error. The sampler just gets stuck, with the progress bar reporting 0 / 10000 samples.

That’s weird :thinking: Are you using seaborn or matplotlib before running your model? Could be related to this error – even though you’re on Linux.

Nope. I’ve got arviz installed in my conda environment but I’m not using it for this example.

Have you tried specifying your priors a bit more tightly? Maybe pm.Normal(0, 5) for the predictors/intercept?

Yup. I’ve tried more informative priors like pm.Normal.dist(0., 1.) and couldn’t get the sampler to start.

Yeah, I think your priors are not informative enough, and I suspect that’s one reason why sampling is slow with cores=1. But if that were the cause, you’d expect an error message when cores > 1, not a silent hang.

This looks a lot like a multiprocessing break issue :confused:
@lucianopaz, I know multiprocessing breaks on Windows and we’ve had some examples with recent versions of Python and MacOS, but is it also the case on Linux now? :fearful:

@AlexAndorra @lucianopaz

So I’ve tried reducing the number of predictors from 90 to 9, and the multi-core (cores=2) sampling started to work. Strange.

If I bump the predictor count up to 73 (still cores=2), the sampler gets stuck again. Here’s the stack trace I get when I Ctrl-C the Python process:

  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/sampling.py", line 1059, in _mp_sample
    for draw in sampler:
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 394, in __iter__
    draw = ProcessAdapter.recv_draw(self._active)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 284, in recv_draw
    ready = multiprocessing.connection.wait(pipes)
  File ".conda/envs/DEV/lib/python3.7/multiprocessing/connection.py", line 920, in wait
    ready = selector.select(timeout)
  File ".conda/envs/DEV/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/sampling.py", line 1068, in _mp_sample
    trace._add_warnings(draw.warnings)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 425, in __exit__
    ProcessAdapter.terminate_all(self._samplers)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 319, in terminate_all
    process.join(timeout)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 274, in join
    self._process.join(timeout)
  File ".conda/envs/DEV/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File ".conda/envs/DEV/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File ".conda/envs/DEV/lib/python3.7/multiprocessing/connection.py", line 920, in wait
    ready = selector.select(timeout)
  File ".conda/envs/DEV/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 85, in <module>
    trace = pm.sample(draws=3000, chains=4, cores=2, tune=2000)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/sampling.py", line 469, in sample
    trace = _mp_sample(**sample_args)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/sampling.py", line 1080, in _mp_sample
    traces, length = _choose_chains(traces, tune)
  File ".conda/envs/DEV/lib/python3.7/site-packages/pymc3/sampling.py", line 1096, in _choose_chains
    raise ValueError("Not enough samples to build a trace.")
ValueError: Not enough samples to build a trace.
Sampling 4 chains, 0 divergences:   0%|                                                                                                              | 0/20000 [01:29<?, ?draws/s]

Any useful info in here? Seems like the parallel sampler is waiting on something - but what?

It’s not the same issue as what we’ve seen on Mac and Windows. If it were, @mishooax would be seeing broken-pipe errors rather than a silent hang before sampling begins.
Linux still defaults to forking new processes with the multiprocessing module, whereas the issue in the other OS’es happens because they spawn new processes instead.
The most likely reason that the sampling doesn’t start has to do with memory usage. When sampling is performed on multiple cores, the entire model and its observations are copied to every new process. With models that are slightly big and have a large number of observations, the process of copying may take a long time or may even crash.

models that are slightly big and have a large number of observations

@lucianopaz Would my linear model with 70-odd predictors and O(1000) observations be too large to copy w/o crashing? That’s still a small model in my book. I was hoping to scale this to something like 200 predictors and O(10^5) observations. I’m running this as a batch job on a supercomputer so the core count and/or mem footprint are not a limitation.

The traceback you are seeing there means that the forked worker processes didn’t sample any points before being terminated by the keyboard interruption. This simply reflects that the worker processes hadn’t started to draw samples because they hadn’t finished copying all of the data around.

Gotcha. However, with my (90-predictor, 1000-obs) model the worker processes still haven’t sampled anything more than half an hour after calling pm.sample. Should I give up on using multiple cores when sampling?

I would say that you should be able to sample from a model like that, though it will take time and some restructuring to handle the parallelization properly. Maybe @twiecki can give some advice on the subject. I’ve dealt with parallelism by ditching Python’s multiprocessing module and running several sample jobs concurrently with different initialization seeds.

You mean something like:

    trace_1 = pm.sample(draws=2000, chains=1, cores=1, tune=2000, random_seed=SEED1)
    trace_2 = pm.sample(draws=2000, chains=1, cores=1, tune=2000, random_seed=SEED2)

and run these in parallel somehow? (how?)

or two completely separate Python processes, each sampling one MCMC chain, plus post-processing (combining the chains) when both are done? I guess this setup could be parallelised with something like mpi4py.

Almost what you said. When I did this, I was running on an SGE-managed cluster, so I submitted a job array that called my Python script and passed the job’s id as a command-line argument: python my_script.py $job_id
Then, in the script, I grabbed the job id from sys.argv and used it to set the random_seed I passed to pm.sample. This approach deploys as many independent processes as chains, and they run in parallel, managed by the cluster’s scheduler or by your computer’s operating system.
You would also need a separate post-processing script where you merge the chains together and use them for whatever it is you actually want.

FWIW, I’m having the exact same issue on macOS.

I defined my logistic regression model and it works with cores=1, but as soon as I increase the cores it just never gets started. Here’s where it hangs:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
NUTS: [X1, X2, X3, X4, Intercept]
Sampling 2 chains, 0 divergences:   0%|          | 0/4000 [00:00<?, ?draws/s]

And this is the stack trace when I stop it after > 30 mins:

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1058             with sampler:
-> 1059                 for draw in sampler:
   1060                     trace = traces[draw.chain - chain]

~/opt/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    393         while self._active:
--> 394             draw = ProcessAdapter.recv_draw(self._active)
    395             proc, is_last, draw, tuning, stats, warns = draw

~/opt/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    283         pipes = [proc._msg_pipe for proc in processes]
--> 284         ready = multiprocessing.connection.wait(pipes)
    285         if not ready:

~/opt/anaconda3/lib/python3.7/multiprocessing/connection.py in wait(object_list, timeout)
    919             while True:
--> 920                 ready = selector.select(timeout)
    921                 if ready:

~/opt/anaconda3/lib/python3.7/selectors.py in select(self, timeout)
    414         try:
--> 415             fd_event_list = self._selector.poll(timeout)
    416         except InterruptedError:

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-126-91d098b82ba5> in <module>
      1 with logistic_model:
----> 2     pred_trace = pm.sample(draws=1000, tune=1000, chains=2, cores=4, init='adapt_diag')

~/opt/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, **kwargs)
    467         _print_step_hierarchy(step)
    468         try:
--> 469             trace = _mp_sample(**sample_args)
    470         except pickle.PickleError:
    471             _log.warning("Could not pickle model, sampling singlethreaded.")

~/opt/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1078         return MultiTrace(traces)
   1079     except KeyboardInterrupt:
-> 1080         traces, length = _choose_chains(traces, tune)
   1081         return MultiTrace(traces)[:length]
   1082     finally:

~/opt/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in _choose_chains(traces, tune)
   1094     lengths = [max(0, len(trace) - tune) for trace in traces]
   1095     if not sum(lengths):
-> 1096         raise ValueError("Not enough samples to build a trace.")
   1097 
   1098     idxs = np.argsort(lengths)[::-1]

ValueError: Not enough samples to build a trace.

I seem to have solved this on Linux by adding the following to the very top of my notebook, before other imports:

    import multiprocessing as mp
    mp.set_start_method('forkserver')
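Note that set_start_method can only be called once per interpreter, so re-running the notebook cell raises a RuntimeError. A slightly more defensive variant of the same fix:

```python
import multiprocessing as mp

# Must run before pymc3/theano create any worker processes.
try:
    mp.set_start_method('forkserver')
except RuntimeError:
    # A start method was already set (e.g. on a notebook re-run); keep it.
    pass
```

If I remember correctly, later pymc3 releases (3.9+) also accept the context directly, e.g. pm.sample(..., mp_ctx='forkserver'), which avoids touching the global start method.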