Seaborn import causes parallel sampling to freeze

I’m not sure if this is the best place to report a potential bug in PyMC3, but the following problem gave me a headache over the last few days until I could isolate the root cause:

I have a model setup that gets stuck during sampling (after initialization, before drawing the first sample) when all three of the following hold: more than 1 core is used, Seaborn is imported, and more than 31 data points are used to fit a Gaussian process.

This code is a minimal example to reproduce the issue:

    import numpy as np
    import pymc3 as pm
    import matplotlib.pyplot as plt
    import seaborn as sns # without this import it's working

    n_data = 32 # n_data = 31 is working
    x_data = np.linspace(0,10,n_data)
    y_data = np.random.normal(size=n_data)
    with pm.Model() as model:
        eta = pm.HalfCauchy('eta',beta=0.5)
        ls = pm.Gamma('ls',mu=1 ,sigma=0.7)
        sigma = pm.HalfCauchy('sigma',beta=1)
        cov = eta**2 * pm.gp.cov.Matern52(1,ls=ls)
        gp = pm.gp.Latent(cov_func=cov)
        f = gp.prior('f',x_data[:,None])
        y = pm.Normal('y',mu=f,sigma=sigma,observed=y_data)
        trace = pm.sample(cores=8) # cores=1 is also working
    pm.traceplot(trace,compact=True)
    plt.show()
    pm.gp.util.plot_gp_dist(plt.gca(),trace['f'],x_data)
    plt.plot(x_data,y_data)
    plt.show()

The sampling is freezing at this point:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 8 jobs)
NUTS: [f_rotated_, sigma, ls, eta]

0.00% [0/16000 00:00<00:00 Sampling 8 chains, 0 divergences]

After hitting Ctrl+C I get the following traceback:

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1485             with sampler:
-> 1486                 for draw in sampler:
   1487                     trace = traces[draw.chain - chain]

/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    491         while self._active:
--> 492             draw = ProcessAdapter.recv_draw(self._active)
    493             proc, is_last, draw, tuning, stats, warns = draw

/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    351         pipes = [proc._msg_pipe for proc in processes]
--> 352         ready = multiprocessing.connection.wait(pipes)
    353         if not ready:

/opt/conda/lib/python3.8/multiprocessing/connection.py in wait(object_list, timeout)
    930             while True:
--> 931                 ready = selector.select(timeout)
    932                 if ready:

/opt/conda/lib/python3.8/selectors.py in select(self, timeout)
    414         try:
--> 415             fd_event_list = self._selector.poll(timeout)
    416         except InterruptedError:

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

KeyboardInterrupt                         Traceback (most recent call last)
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1497                     if callback is not None:
-> 1498                         callback(trace=trace, draw=draw)
   1499 

/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in __exit__(self, *args)
    524     def __exit__(self, *args):
--> 525         ProcessAdapter.terminate_all(self._samplers)
    526 

/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in terminate_all(processes, patience)
    386                     raise multiprocessing.TimeoutError()
--> 387                 process.join(timeout)
    388         except multiprocessing.TimeoutError:

/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in join(self, timeout)
    341     def join(self, timeout=None):
--> 342         self._process.join(timeout)
    343 

/opt/conda/lib/python3.8/multiprocessing/process.py in join(self, timeout)
    148         assert self._popen is not None, 'can only join a started process'
--> 149         res = self._popen.wait(timeout)
    150         if res is not None:

/opt/conda/lib/python3.8/multiprocessing/popen_fork.py in wait(self, timeout)
     43                 from multiprocessing.connection import wait
---> 44                 if not wait([self.sentinel], timeout):
     45                     return None

/opt/conda/lib/python3.8/multiprocessing/connection.py in wait(object_list, timeout)
    930             while True:
--> 931                 ready = selector.select(timeout)
    932                 if ready:

/opt/conda/lib/python3.8/selectors.py in select(self, timeout)
    414         try:
--> 415             fd_event_list = self._selector.poll(timeout)
    416         except InterruptedError:

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-2-4cd3f9bd1287> in <module>
     15     f = gp.prior('f',x_data[:,None])
     16     y = pm.Normal('y',mu=f,sigma=sigma,observed=y_data)
---> 17     trace = pm.sample(cores=8) # cores=1 is also working
     18 pm.traceplot(trace,compact=True)
     19 plt.show()

/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    543         _print_step_hierarchy(step)
    544         try:
--> 545             trace = _mp_sample(**sample_args, **parallel_args)
    546         except pickle.PickleError:
    547             _log.warning("Could not pickle model, sampling singlethreaded.")

/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1510     except KeyboardInterrupt:
   1511         if discard_tuned_samples:
-> 1512             traces, length = _choose_chains(traces, tune)
   1513         else:
   1514             traces, length = _choose_chains(traces, 0)

/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _choose_chains(traces, tune)
   1528     lengths = [max(0, len(trace) - tune) for trace in traces]
   1529     if not sum(lengths):
-> 1530         raise ValueError("Not enough samples to build a trace.")
   1531 
   1532     idxs = np.argsort(lengths)[::-1]

ValueError: Not enough samples to build a trace.

I solved this issue for myself by removing the Seaborn import, as I actually don’t need it anymore in this context. However, as many PyMC3 users also use Seaborn, I think it might be a good idea to fix this bug somehow.

Version Information:

  • pymc3 3.9.3
  • theano 1.0.5
  • seaborn 0.10.1
  • conda 4.8.3
  • Python 3.8.5
  • Ubuntu 20.04 LTS

Thanks for reporting and tracking this down to this point. It does indeed sound like a bug (although maybe not a pymc3 bug).
Could you open an issue on the pymc3 GitHub, and also include which CPU (mostly whether it is AMD or Intel) and which BLAS implementation you are using?
If you are using conda-forge you can check BLAS with conda list libblas, otherwise the output of np.__config__.show() should hopefully include that.
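In case it helps with collecting this, here is a minimal sketch that gathers the interpreter, CPU architecture, and NumPy build configuration in one place. The helper name describe_environment is hypothetical (it is not part of NumPy or PyMC3), and the exact text that np.__config__.show() prints differs between NumPy versions, so the result is best pasted as-is rather than parsed.

```python
import contextlib
import io
import platform


def describe_environment():
    """Collect basic interpreter, CPU, and NumPy build information.

    Hypothetical helper for bug reports; not part of NumPy or PyMC3.
    """
    info = {
        "python": platform.python_version(),
        "machine": platform.machine(),
        "processor": platform.processor(),
    }
    try:
        import numpy as np
    except ImportError:
        info["numpy_config"] = "numpy is not installed"
    else:
        # np.__config__.show() prints to stdout, so capture it as a string.
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            np.__config__.show()
        info["numpy_config"] = buffer.getvalue()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}:\n{value}\n")
```

The captured configuration text is where the BLAS/LAPACK libraries that NumPy was built against should show up, if they are reported at all.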


Unfortunately I don’t have a GitHub account and am currently struggling to create one (they are practically rejecting all my passwords…). So I would be happy if maybe someone else could open an issue and link this thread. :wink:

So this is my output of conda list libblas:

# packages in environment at /opt/conda:
#
# Name                    Version                   Build  Channel
libblas                   3.8.0               14_openblas    conda-forge

And this one from lscpu:

Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   39 bits physical, 48 bits virtual
CPU(s):                          8
On-line CPU(s) list:             0-7
Thread(s) per core:              2
Core(s) per socket:              4
Socket(s):                       1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           142
Model name:                      Intel(R) Core(TM) i7-8665U CPU @ 1.90GHz
Stepping:                        12
CPU MHz:                         2112.008
BogoMIPS:                        4224.01
Hypervisor vendor:               Microsoft
Virtualization type:             full
L1d cache:                       128 KiB
L1i cache:                       128 KiB
L2 cache:                        1 MiB
L3 cache:                        8 MiB
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 xsaves flush_l1d arch_capabilities

However, during my debugging I also tested it on another machine with an AMD processor and the bug was still present.

Further information: The bug is present if you execute the code in a Jupyter notebook, but also as a plain script.

I will try to reproduce this tomorrow, when I have proper internet again (train).

We can discuss this here, no problem. I still made a GitHub issue for future reference: "Sampling freezes with openblas and multiprocessing", Issue #4092 in pymc-devs/pymc on GitHub.

they are practically rejecting all my passwords

I’d still like to suggest a password manager though. :wink:

This could be a fork safety issue in OpenBLAS. If so, then any of the following should make the problem go away:

  • Set the number of threads for BLAS to 1 (execute import os; os.environ['OPENBLAS_NUM_THREADS'] = '1' first thing after restarting the notebook).
  • Use spawn or forkserver as the multiprocessing start method: set mp_ctx='spawn' or mp_ctx='forkserver' in pm.sample.
  • Switch to a different BLAS implementation: run conda install "libblas=*=*mkl" in the conda environment you are using as kernel (MKL should be faster most of the time anyway, so if you want to keep it, you might want to add that as a pin).
  • Find and fix the underlying issue in openblas. :slight_smile:

If you have time, it would help if you could try those (the last one gives bonus points).
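For the first two suggestions, a minimal sketch of how they could look in a script follows. The helper pick_start_method is hypothetical and only chooses a non-fork start method that the platform supports; the pm.sample call is left as a comment because it needs the model from the original post. It assumes the environment variable is set before numpy, theano, or pymc3 are imported, since a BLAS library that is already loaded may ignore a later change.

```python
import os

# Assumption: nothing has imported numpy/theano/pymc3 yet in this process.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import multiprocessing


def pick_start_method(preferred=("forkserver", "spawn")):
    """Return the first non-fork start method available on this platform.

    Hypothetical helper, not part of PyMC3.
    """
    available = multiprocessing.get_all_start_methods()
    for name in preferred:
        if name in available:
            return name
    raise RuntimeError(f"none of {preferred} is available: {available}")


if __name__ == "__main__":
    method = pick_start_method()
    print("using start method:", method)
    # With the model from the original post (not executed here):
    # with model:
    #     trace = pm.sample(cores=8, mp_ctx=method)
```

When spawn or forkserver is used from a plain script, the sampling call has to sit under an if __name__ == "__main__": guard, because the child processes import the main module again.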


@krum_sv Just wanted to ping you, if you had time to try those. I can’t reproduce this issue locally, so it would be great to hear if my hunch is correct. I’m pretty sure you’re not the only one having this issue. :wink:

Sorry, the last few days I had too many other tasks on my desk, but today I gave your suggestions a try :wink: :

So, the bug is fixed with the following setup:

  • pm.sample(...,mp_ctx='spawn')
  • pm.sample(...,mp_ctx='forkserver')

But the bug is still present with the following setup:

  • import os; os.environ['OPENBLAS_NUM_THREADS'] = '1'
  • conda install libblas=*=*mkl

However, I also observed today that I cannot reproduce the bug reliably: I started today with a fresh JupyterLab instance and the bug was not present anymore (without any changes in the setup). But after opening another medium-sized notebook (~3MB) the bug was back again. Afterwards I was not able to get back to the bug-free state, even after restarting Jupyter. And just for the record: the bug is not only present inside Jupyter notebooks, but also if I run the code as a plain script.
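Since only changing the start method helped, the fork start method itself seems to be involved. The sketch below is a generic illustration of why fork can be unsafe once a process has started threads; it is not a reproduction of this bug and makes no claim about what OpenBLAS or Seaborn actually do. It assumes a Unix system (os.fork), and the function name is hypothetical. A background thread holds a lock, the process forks, and the child inherits the lock in its locked state without the thread that would release it.

```python
import os
import threading


def child_can_acquire_lock_after_fork(timeout=0.5):
    """Fork while another thread holds a lock; report if the child gets it."""
    lock = threading.Lock()
    holding = threading.Event()
    release = threading.Event()

    def hold_lock():
        with lock:
            holding.set()   # signal that the lock is now held
            release.wait()  # keep holding it until the parent says so

    worker = threading.Thread(target=hold_lock)
    worker.start()
    holding.wait()

    read_fd, write_fd = os.pipe()
    pid = os.fork()
    if pid == 0:
        # Child process: only the forking thread was copied. The worker
        # thread does not exist here, but the lock is still marked as held.
        os.close(read_fd)
        acquired = lock.acquire(timeout=timeout)
        os.write(write_fd, b"1" if acquired else b"0")
        os._exit(0)

    # Parent process: read the child's answer, then clean up.
    os.close(write_fd)
    answer = os.read(read_fd, 1)
    os.close(read_fd)
    os.waitpid(pid, 0)
    release.set()
    worker.join()
    return answer == b"1"


if __name__ == "__main__":
    print("child acquired the lock:", child_can_acquire_lock_after_fork())
```

Without the timeout the child would wait forever, which looks much like sampling that never draws its first sample. With spawn or forkserver the workers start from a fresh interpreter and do not inherit such state.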