Blackjax sampler suffers divergences with TruncatedNormal likelihood

Hi there! I’ve recently been trying out the JAX-based samplers (blackjax in particular), but I’ve noticed some unexpected behavior when working with a model that has a TruncatedNormal likelihood. PyMC’s default NUTS sampler seems to handle it fine, but when running with blackjax I get an explosion of divergences roughly 9 out of 10 runs. I’ve distilled it down to a test case that recreates the issue:

import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.TruncatedNormal("y", mu=mu, sigma=sigma, lower=-10, upper=10, size=(N_OBSERVATIONS,))
    prior_trace = pm.sample_prior_predictive()

pm.model_to_graphviz(model)

Everything looks good with default NUTS:

data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    idata = pm.sample()
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)

Not so with blackjax:

data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    idata = pm.sample(nuts_sampler="blackjax")
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
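For reference, a quick way to quantify the failure from the returned trace (a minimal sketch using ArviZ, which PyMC already installs; variable names match the snippets above):

import arviz as az

# count divergent transitions across all chains and draws
print(int(idata.sample_stats["diverging"].sum()))

# convergence diagnostics (r_hat, ESS) for the free parameters
print(az.summary(idata, var_names=["mu", "sigma"]))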

I wouldn’t call myself an expert on MCMC methods, but as far as I can tell, it seems like there shouldn’t be any geometry issues affecting this model. This has me wondering if there is a bug in the sampler implementation here, but I’d also be curious to know if anybody can spot something I might be doing wrong.

Thanks for the feedback!

Tech details

macOS 15
python 3.11.11

Key packages:

- pymc=5.20.0
- blackjax=1.2.4
- tensorflow-probability=0.24.0

Do you see the same problem with the numpyro NUTS sampler?

Also, you may want to seed that prior predictive draw to avoid stochastic behavior between runs, just while debugging.

> Do you see the same problem with the numpyro NUTS sampler?

Yes, it seems to suffer from the same problem.
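For reference, the check was roughly the earlier snippet with the sampler swapped out (same model, data, and variable names as above):

data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    idata_numpyro = pm.sample(nuts_sampler="numpyro")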

> Also, you may want to seed that prior predictive draw to avoid stochastic behavior between runs, just while debugging.

Good call! I’ll seed it with pm.sample_prior_predictive(random_seed=100) for subsequent runs.

At first I suspected that the logp might be less stable on the JAX backend, but I’ve tried sampling with JAX via both the PyMC and nutpie samplers and those don’t diverge, so I think numpyro/blackjax are just doing a poorer job here.
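In case it helps with reproducing the comparison, the nutpie/JAX run can be set up roughly like this (a sketch: nutpie.compile_pymc_model accepts a backend argument, and pm.observe returns a copy of the model with y observed):

import nutpie

# observed copy of the model, using the same prior draw as above
observed_model = pm.observe(model, {y: data})

# compile with nutpie's JAX backend instead of the default numba backend
compiled = nutpie.compile_pymc_model(observed_model, backend="jax")
idata_nutpie = nutpie.sample(compiled)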

Comparison gist: numpyro_diverging_truncated_model.ipynb · GitHub

You may want to post on the blackjax repo on GitHub.

And/or in the numpyro discourse: numpyro - Pyro Discussion Forum

CC @junpenglao


Thanks for the quick response! I appreciate the robust level of support in this community 🙂

That’s an interesting result - I’ll go ahead and raise this on the blackjax repo. I’m not very familiar with the implementation details involved, but hopefully a high-level description of the problem will help them catch the potential bug.

For future reference: Blackjax sampler suffers divergences with TruncatedNormal likelihood · Issue #775 · blackjax-devs/blackjax · GitHub