I’ve swapped to numpyro and jax for sampling, but I don’t see any mention of divergences. Does sampling.jax.sample_numpyro_nuts run the same checks that the regular pymc.sample does?
sample_numpyro_nuts does track divergences and other sampling statistics, though it is true that they are not reported on the fly. You can find them in the sample_stats group of the resulting InferenceData object. For example, to get a count of divergences, you could write:
trace.sample_stats['diverging'].sum()
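If you want more than the total, a small helper can break the count down per chain and show which draws diverged. This is a sketch, not part of PyMC or ArviZ: summarize_divergences is a hypothetical name, and it only assumes that indexing sample_stats with "diverging" gives a boolean array with dimensions (chain, draw), as InferenceData.sample_stats does.

```python
import numpy as np


def summarize_divergences(sample_stats):
    """Summarize divergent transitions from a sample_stats group.

    Hypothetical helper. Assumes sample_stats["diverging"] is a boolean
    array-like with dimensions (chain, draw), e.g. trace.sample_stats.
    """
    # np.asarray pulls the raw values out of an xarray DataArray (or a plain array)
    diverging = np.asarray(sample_stats["diverging"], dtype=bool)
    per_chain = diverging.sum(axis=1)  # divergences in each chain
    total = int(per_chain.sum())  # divergences across all chains
    return {
        "total": total,
        "per_chain": per_chain.tolist(),
        # share of the stored draws that were divergent
        "fraction": total / diverging.size,
        # (chain, draw) index pairs of each divergent draw
        "locations": np.argwhere(diverging).tolist(),
    }
```

With a real trace you would call summarize_divergences(trace.sample_stats). Divergences concentrated in a single chain often point to a different problem than divergences spread evenly across all chains, which is why the per-chain breakdown can be worth a look.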