Numpyro Convergence Diagnostics

I’ve switched to NumPyro and JAX for sampling, but I don’t see any mention of divergences. Does sampling.jax.sample_numpyro_nuts run the same checks that the regular pymc.sample does?

sample_numpyro_nuts does track divergences and other sampler statistics, though it is true that they are not reported on the fly. You can find them in the sample_stats group of the returned InferenceData object. For example, to get the total number of divergences across all chains and draws, you could write:

trace.sample_stats['diverging'].sum()
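If you want something closer to the divergence warning that pymc.sample prints, a small helper can inspect the same `diverging` statistic after sampling and break the count down per chain. This is a minimal sketch, not a PyMC API: `report_divergences` is a hypothetical name, and it assumes `sample_stats["diverging"]` has dims (chain, draw). It only uses NumPy, so the demo at the bottom substitutes a stand-in object for a real InferenceData.

```python
import warnings
from types import SimpleNamespace

import numpy as np


def report_divergences(idata):
    """Count divergent transitions per chain (hypothetical helper, not part of PyMC).

    Assumes idata.sample_stats["diverging"] has dims (chain, draw).
    np.asarray() turns an xarray DataArray (or a plain array) into a NumPy array.
    """
    diverging = np.asarray(idata.sample_stats["diverging"], dtype=bool)
    per_chain = diverging.sum(axis=1)  # number of divergences in each chain
    total = int(per_chain.sum())
    if total > 0:
        # Emit a warning, loosely mimicking what pymc.sample reports.
        warnings.warn(
            f"There were {total} divergences in the saved draws "
            f"(per chain: {per_chain.tolist()})."
        )
    return total, per_chain.tolist()


# Stand-in for a real InferenceData: 2 chains x 4 draws, one divergence in chain 0.
fake_trace = SimpleNamespace(
    sample_stats={
        "diverging": np.array([[False, True, False, False],
                               [False, False, False, False]])
    }
)
total, per_chain = report_divergences(fake_trace)
print(total, per_chain)  # prints: 1 [1, 0]
```

With a real result you would call `report_divergences(trace)` on the object returned by sample_numpyro_nuts. For the other checks pymc.sample performs, such as R-hat and effective sample size, one option is ArviZ's `az.summary(trace)`, whose `r_hat` and `ess_bulk` columns can be inspected the same way.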
