Hi there! I’ve recently been trying out the JAX-based samplers (blackjax in particular), but I’ve noticed some unexpected behavior with a model that has a TruncatedNormal likelihood. PyMC’s default NUTS sampler handles it fine, but when sampling with blackjax I get a large number of divergences in about 9 out of 10 runs. I’ve distilled it down to a test case that reproduces the issue:
```python
import pymc as pm

N_OBSERVATIONS = 100  # assumed value; the original snippet doesn't define it

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.TruncatedNormal("y", mu=mu, sigma=sigma, lower=-10, upper=10, size=(N_OBSERVATIONS,))
    prior_trace = pm.sample_prior_predictive()
```
Everything looks good with default NUTS:
```python
# Use the first prior predictive draw as the observed data
data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    idata = pm.sample()
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
Not so with blackjax:
```python
data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    idata = pm.sample(nuts_sampler="blackjax")
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
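To put a number on the divergences and compare the two runs directly, one option is to count divergent transitions per chain from the boolean `diverging` variable that PyMC stores in `idata.sample_stats` (assuming the blackjax path records it under the same name as the default sampler). The helper name `divergences_per_chain` is hypothetical, and the demo below uses a hand-built dataset with made-up values rather than real sampler output:

```python
import numpy as np
import xarray as xr


def divergences_per_chain(sample_stats: xr.Dataset) -> np.ndarray:
    """Count divergent transitions in each chain.

    Expects the `sample_stats` group of the InferenceData returned by
    `pm.sample`, whose boolean `diverging` variable has dims (chain, draw).
    """
    # Summing booleans over draws gives the number of divergences per chain
    return sample_stats["diverging"].sum(dim="draw").values


# Hand-built stand-in for idata.sample_stats (made-up values, not results from this post)
demo_stats = xr.Dataset(
    {"diverging": (("chain", "draw"), np.array([[True, False, True], [False, False, False]]))}
)
print(divergences_per_chain(demo_stats))
```

With the real runs, this would be called as `divergences_per_chain(idata.sample_stats)` after each `pm.sample` call, so the default NUTS and blackjax results can be compared side by side.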
I wouldn’t call myself an expert on MCMC methods, but as far as I can tell, this model shouldn’t have any difficult posterior geometry. This has me wondering whether there is a bug in the sampler implementation, but I’d also be curious to know if anybody can spot something I might be doing wrong.
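One way to sanity-check the "no geometry issues" intuition is to write the TruncatedNormal log-likelihood out by hand in JAX and confirm that its value and gradient stay finite over a range of parameter values. This is a minimal standalone sketch, not PyMC's own implementation: it assumes `sigma` is sampled on the log scale (as is typical for a positive parameter like a HalfNormal), uses hypothetical synthetic data, and the grid ranges are arbitrary. It only looks at the likelihood term, so it ignores the priors and does not exercise blackjax itself. The names `truncnorm_loglik` and `loglik_value_and_grad` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # float64, so the check isn't limited by float32 precision

LOWER, UPPER = -10.0, 10.0  # truncation bounds from the model above


def truncnorm_loglik(mu, log_sigma, y):
    """Summed TruncatedNormal log-likelihood of y, parameterised by log(sigma)."""
    sigma = jnp.exp(log_sigma)
    alpha = (LOWER - mu) / sigma
    beta = (UPPER - mu) / sigma
    # Normalising constant: probability mass of N(mu, sigma) inside [LOWER, UPPER]
    log_z = jnp.log(norm.cdf(beta) - norm.cdf(alpha))
    return jnp.sum(norm.logpdf(y, loc=mu, scale=sigma) - log_z)


# Value and gradient with respect to (mu, log_sigma)
loglik_value_and_grad = jax.jit(jax.value_and_grad(truncnorm_loglik, argnums=(0, 1)))

# Hypothetical synthetic data, well inside the truncation bounds
y = 0.3 + 0.5 * jax.random.normal(jax.random.PRNGKey(0), (100,))

# Evaluate over a grid of (mu, log_sigma) and check that everything is finite
mus, log_sigmas = jnp.meshgrid(jnp.linspace(-2.0, 2.0, 9), jnp.linspace(-3.0, 1.0, 9))
vals, (g_mu, g_log_sigma) = jax.vmap(lambda m, s: loglik_value_and_grad(m, s, y))(
    mus.ravel(), log_sigmas.ravel()
)
print(bool(jnp.all(jnp.isfinite(vals))), bool(jnp.all(jnp.isfinite(g_mu))), bool(jnp.all(jnp.isfinite(g_log_sigma))))
```

Because the bounds sit many standard deviations away from the data here, the normalising term is essentially zero and the likelihood behaves like a plain Normal; a finite, smooth surface in this check would point away from the likelihood itself as the source of the divergences.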
Thanks for the feedback!
Tech details
macOS 15
Python 3.11.11
Key packages:
- pymc=5.20.0
- blackjax=1.2.4
- tensorflow-probability=0.24.0