Sampling with blackjax - no divergence report

Hi all,

I’m running a PyMC model using the blackjax sampler, like so: pm.sample(nuts_sampler="blackjax"). This seems to work fine. However, I am wondering about divergences, as these don't seem to be reported the way they are with the standard NUTS sampler. If there is no warning, can I assume that no divergences occurred? Or are they generally not reported?


Neither of the JAX samplers reports divergences when it finishes sampling, so the absence of a warning does not mean there were none. You can check the idata by looking at idata.sample_stats.diverging (take the sum to get the total number of divergences, or sum over draws to get the count per chain), or by inspecting plots that show divergences (trace, pair…).
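
In case it helps, here is a minimal sketch of that check. The helper name count_divergences is made up for illustration, and the idata below is a small synthetic stand-in (an xarray Dataset with made-up divergence flags wrapped in a namespace) rather than real sampler output. It assumes sample_stats has the usual "chain" and "draw" dimensions.

```python
from types import SimpleNamespace

import numpy as np
import xarray as xr


def count_divergences(idata):
    """Return (total, per_chain) divergence counts from idata.sample_stats.

    Hypothetical helper: assumes `diverging` is a boolean variable with
    dimensions ("chain", "draw").
    """
    diverging = idata.sample_stats["diverging"]
    total = int(diverging.sum())
    per_chain = diverging.sum(dim="draw").values.astype(int).tolist()
    return total, per_chain


# Synthetic stand-in for the object returned by
# pm.sample(nuts_sampler="blackjax"): 2 chains, 5 draws, made-up flags.
flags = np.array(
    [
        [False, True, False, False, True],
        [False, False, False, True, False],
    ]
)
sample_stats = xr.Dataset({"diverging": (("chain", "draw"), flags)})
idata = SimpleNamespace(sample_stats=sample_stats)

total, per_chain = count_divergences(idata)
print(f"total divergences: {total}")
print(f"divergences per chain: {per_chain}")
```

With a real run you would pass the idata returned by pm.sample instead of the stand-in. A non-zero total is the same signal the standard sampler's warning would have given you.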
