I have a model that I am sampling from as follows:
with m.model:
    m.trace = pm.sample(cores=1, chains=1)
This works fine (well, there are divergences, so not fine, but it works).
But when I remove the cores=1, chains=1, I get a bad initial energy error:
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [weights, err_sd, medium influences, AND Output, βod2, βod1, βtemp2, βtemp1]
Sampling 4 chains: 0%| | 0/4000 [00:00<?, ?draws/s]/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
[... the same RuntimeWarning is repeated three more times, once per remaining chain ...]
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
Series([], )
Any idea what could be causing this? Is this because the parallel sampling version has multiple different starting points?
This seems particularly unfortunate, because the way I have been debugging such errors is to use only a single chain and a single core, to simplify things.
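To illustrate my hypothesis: as far as I understand, the default init ("jitter+adapt_diag") adds independent uniform jitter to the test point for each chain, so different chains can land at different (and possibly problematic) starting points. A rough numpy sketch of what I believe happens (the jitter range and per-chain independence are my reading of the source, not something I have confirmed):

```python
import numpy as np

# Sketch of my understanding of the default "jitter+adapt_diag" init:
# each chain starts from the model's test point plus independent
# uniform jitter in (-1, 1), so some chains can land where logp or
# its gradient is no longer finite even if the test point itself is fine.
rng = np.random.default_rng(0)
test_point = np.zeros(5)  # stand-in for the model's actual test point
n_chains = 4
starts = [test_point + rng.uniform(-1, 1, size=test_point.shape)
          for _ in range(n_chains)]
```

If that is the cause, I would expect sampling with init="adapt_diag" (no jitter) to behave more like my single-chain run.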
But I have been running check_test_point on the model before sampling and don’t see any problems:
βtemp1 -0.92
βtemp2 2.08
βod1 -0.92
βod2 -0.23
AND Output_interval__ -4.27
medium influences -19.35
err_sd_log__ -880.76
weights_stickbreaking__ -3.20
obs -2744.15
Name: Log-probability of test_point, dtype: float64
I suspect that the error condition is triggered by this check in the sampler:
start = self.integrator.compute_state(q0, p0)
if not np.isfinite(start.energy):
I believe we have had a discussion about this to the effect that it might not just be the point having an infinite logp, but that the gradient might be involved, too. I know @junpenglao has posted about this (and how to check the value of the gradient in the debugger), but after about 30 minutes of keyword-searching discourse.pymc.io, I just cannot find the post.
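For reference, here is a minimal numpy sketch of the check I think is failing: before sampling, the integrator draws a random momentum and computes the Hamiltonian energy at the start point, and "Bad initial energy" is raised when that energy is not finite, which (if the discussion I remember is right) can come from the logp or from the gradient during the first leapfrog step. The helper names and example logp functions below are my own stand-ins, not PyMC3's API:

```python
import numpy as np

def start_point_ok(logp_fn, dlogp_fn, q0, rng):
    """Mimic the initial-energy check: draw a momentum, compute the
    Hamiltonian, and report whether logp, gradient, and energy are
    finite. (logp_fn, dlogp_fn, q0 are hypothetical stand-ins.)"""
    p0 = rng.standard_normal(q0.shape)
    logp = logp_fn(q0)
    grad = dlogp_fn(q0)
    energy = -logp + 0.5 * p0 @ p0  # potential + kinetic
    return {
        "logp_finite": bool(np.isfinite(logp)),
        "grad_finite": bool(np.all(np.isfinite(grad))),
        "energy_finite": bool(np.isfinite(energy)),
    }

rng = np.random.default_rng(1)
# A standard-normal logp is finite everywhere, so the check passes...
gauss = start_point_ok(lambda q: -0.5 * q @ q, lambda q: -q,
                       np.ones(3), rng)
# ...but a half-line support (logp = sum(log(q))) fails off-support,
# the way a jittered start point might:
bad = start_point_ok(lambda q: np.sum(np.log(q)), lambda q: 1.0 / q,
                     np.array([-0.5, 1.0]), rng)
```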