I’m not sure if this is the best place to report a potential bug in PyMC3, but the following problem caused me some headaches over the last few days until I could isolate the root cause:
I have a model setup which gets stuck during sampling (after initialization, before drawing the first sample) when all three of the following conditions hold: more than 1 core is used, Seaborn is imported, and more than 31 data points are used to fit a Gaussian process.
This code is a minimal example to reproduce the issue:
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt
import seaborn as sns # without this import it's working
# Minimal reproduction: sampling hangs with n_data >= 32, cores > 1 and
# seaborn imported (see the inline notes for the settings that do work).
n_data = 32  # n_data = 31 is working
x_data = np.linspace(0, 10, n_data)
y_data = np.random.normal(size=n_data)

with pm.Model() as model:
    # Priors: eta scales the covariance, ls is the kernel length scale,
    # sigma is the noise of the Normal likelihood.
    eta = pm.HalfCauchy('eta', beta=0.5)
    ls = pm.Gamma('ls', mu=1, sigma=0.7)
    sigma = pm.HalfCauchy('sigma', beta=1)

    # Latent GP with a Matern 5/2 covariance; x_data is passed as a column.
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ls)
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior('f', x_data[:, None])

    y = pm.Normal('y', mu=f, sigma=sigma, observed=y_data)
    trace = pm.sample(cores=8)  # cores=1 is also working

pm.traceplot(trace, compact=True)
plt.show()

# Plot the sampled values of f together with the raw data.
pm.gp.util.plot_gp_dist(plt.gca(), trace['f'], x_data)
plt.plot(x_data, y_data)
plt.show()
The sampling freezes at this point:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 8 jobs)
NUTS: [f_rotated_, sigma, ls, eta]
0.00% [0/16000 00:00<00:00 Sampling 8 chains, 0 divergences]
After hitting Ctrl+C I get the following traceback:
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1485 with sampler:
-> 1486 for draw in sampler:
1487 trace = traces[draw.chain - chain]
/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in __iter__(self)
491 while self._active:
--> 492 draw = ProcessAdapter.recv_draw(self._active)
493 proc, is_last, draw, tuning, stats, warns = draw
/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
351 pipes = [proc._msg_pipe for proc in processes]
--> 352 ready = multiprocessing.connection.wait(pipes)
353 if not ready:
/opt/conda/lib/python3.8/multiprocessing/connection.py in wait(object_list, timeout)
930 while True:
--> 931 ready = selector.select(timeout)
932 if ready:
/opt/conda/lib/python3.8/selectors.py in select(self, timeout)
414 try:
--> 415 fd_event_list = self._selector.poll(timeout)
416 except InterruptedError:
KeyboardInterrupt:
During handling of the above exception, another exception occurred:
KeyboardInterrupt Traceback (most recent call last)
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1497 if callback is not None:
-> 1498 callback(trace=trace, draw=draw)
1499
/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in __exit__(self, *args)
524 def __exit__(self, *args):
--> 525 ProcessAdapter.terminate_all(self._samplers)
526
/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in terminate_all(processes, patience)
386 raise multiprocessing.TimeoutError()
--> 387 process.join(timeout)
388 except multiprocessing.TimeoutError:
/opt/conda/lib/python3.8/site-packages/pymc3/parallel_sampling.py in join(self, timeout)
341 def join(self, timeout=None):
--> 342 self._process.join(timeout)
343
/opt/conda/lib/python3.8/multiprocessing/process.py in join(self, timeout)
148 assert self._popen is not None, 'can only join a started process'
--> 149 res = self._popen.wait(timeout)
150 if res is not None:
/opt/conda/lib/python3.8/multiprocessing/popen_fork.py in wait(self, timeout)
43 from multiprocessing.connection import wait
---> 44 if not wait([self.sentinel], timeout):
45 return None
/opt/conda/lib/python3.8/multiprocessing/connection.py in wait(object_list, timeout)
930 while True:
--> 931 ready = selector.select(timeout)
932 if ready:
/opt/conda/lib/python3.8/selectors.py in select(self, timeout)
414 try:
--> 415 fd_event_list = self._selector.poll(timeout)
416 except InterruptedError:
KeyboardInterrupt:
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-2-4cd3f9bd1287> in <module>
15 f = gp.prior('f',x_data[:,None])
16 y = pm.Normal('y',mu=f,sigma=sigma,observed=y_data)
---> 17 trace = pm.sample(cores=8) # cores=1 is also working
18 pm.traceplot(trace,compact=True)
19 plt.show()
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
543 _print_step_hierarchy(step)
544 try:
--> 545 trace = _mp_sample(**sample_args, **parallel_args)
546 except pickle.PickleError:
547 _log.warning("Could not pickle model, sampling singlethreaded.")
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1510 except KeyboardInterrupt:
1511 if discard_tuned_samples:
-> 1512 traces, length = _choose_chains(traces, tune)
1513 else:
1514 traces, length = _choose_chains(traces, 0)
/opt/conda/lib/python3.8/site-packages/pymc3/sampling.py in _choose_chains(traces, tune)
1528 lengths = [max(0, len(trace) - tune) for trace in traces]
1529 if not sum(lengths):
-> 1530 raise ValueError("Not enough samples to build a trace.")
1531
1532 idxs = np.argsort(lengths)[::-1]
ValueError: Not enough samples to build a trace.
I solved this issue for myself by removing the Seaborn import, as I actually don’t need it anymore in this context. However, as many PyMC3 users also use Seaborn, I think it might be a good idea to fix this bug somehow.
Version Information:
- pymc3 3.9.3
- theano 1.0.5
- seaborn 0.10.1
- conda 4.8.3
- Python 3.8.5
- Ubuntu 20.04 LTS