sample_numpyro_nuts does track divergences and other sampling statistics, though it is true that they are not reported on the fly. You can find them in the resulting InferenceData object, in the sample_stats attribute. For example, to get the total number of divergences across all chains, you could write:
trace.sample_stats['diverging'].sum()
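The diverging statistic is a boolean array with one entry per (chain, draw), so beyond the total you can also break the count down by chain or express it as a fraction of all transitions. Below is a minimal sketch that uses only NumPy. The small hand-written array is a hypothetical stand-in for trace.sample_stats["diverging"].values, and divergence_summary is an illustrative helper, not part of PyMC or ArviZ.

```python
import numpy as np

# Hypothetical stand-in for trace.sample_stats["diverging"].values:
# a boolean array with dims (chain, draw), True where a transition diverged.
diverging = np.array([
    [False, True, False, False, False],
    [False, False, False, True, True],
])

def divergence_summary(diverging: np.ndarray) -> dict:
    """Summarise a (chain, draw) boolean divergence array (illustrative helper)."""
    total = int(diverging.sum())                 # divergences across all chains
    per_chain = diverging.sum(axis=1).tolist()   # divergences in each chain
    fraction = total / diverging.size            # share of divergent transitions
    return {"total": total, "per_chain": per_chain, "fraction": fraction}

summary = divergence_summary(diverging)
print(summary)  # -> {'total': 3, 'per_chain': [1, 2], 'fraction': 0.3}
```

On the real InferenceData object you can skip the conversion to NumPy: since sample_stats["diverging"] is an xarray DataArray, trace.sample_stats["diverging"].sum(dim="draw") gives the per-chain counts directly.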