pm.sample() can optionally return its trace as an ArviZ InferenceData object instead of a MultiTrace, and that will soon be the default behavior. InferenceData is better than MultiTrace in several ways, but it seems to be missing a key feature.
When NUTS sampling encounters divergent transitions, the details of the divergences are captured in the MultiTrace object. For example, consider the Devil’s Funnel (ht: @AlexAndorra):
    import pymc3 as pm  # assuming PyMC3, where pm.sample() still returns a MultiTrace by default

    with pm.Model() as m:
        v = pm.Normal("v", 0.0, 3.0)
        x = pm.Normal("x", 0.0, pm.math.exp(v))
        trace = pm.sample(random_seed=19950526)
This model produces many divergent transitions. We can see the details of these transitions in the MultiTrace object by examining trace.report._chain_warnings:
Note that each warning captures the leapfrog source and destination, that is, the pair of points between which the energy change was too large. This information is useful for determining why the divergence occurred. Of course, for the Devil’s Funnel it is obvious why the divergences occurred, but sometimes it is not so obvious. Here is a plot from another (more complex) model that shows the leapfrog sources and destinations of the divergent transitions:
Without the detail to draw this plot, it is difficult to diagnose these divergences.
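To make concrete what "leapfrog source and destination" means, here is a minimal pure-Python sketch of a single leapfrog step on the funnel model above. It is not PyMC's implementation: the unit mass matrix, the function names, the starting point, and the energy threshold E_MAX = 1000 are all assumptions made for illustration.

```python
import math

E_MAX = 1000.0  # assumed divergence threshold, for illustration only


def potential(q):
    """Negative log density (up to a constant) of the funnel at q = (v, x)."""
    v, x = q
    # v ~ Normal(0, 3) and x ~ Normal(0, exp(v))
    return v**2 / 18.0 + v + 0.5 * x**2 * math.exp(-2.0 * v)


def grad_potential(q):
    """Gradient of the potential with respect to (v, x)."""
    v, x = q
    scale = math.exp(-2.0 * v)
    return (v / 9.0 + 1.0 - x**2 * scale, x * scale)


def kinetic(p):
    """Kinetic energy for a unit mass matrix (an assumption of this sketch)."""
    return 0.5 * sum(pi**2 for pi in p)


def leapfrog_step(q, p, eps):
    """Take one leapfrog step from the source q with momentum p.

    Returns (destination, new momentum, energy change, diverged).
    """
    energy_before = potential(q) + kinetic(p)
    # Half step for momentum, full step for position, half step for momentum.
    g = grad_potential(q)
    p_half = tuple(pi - 0.5 * eps * gi for pi, gi in zip(p, g))
    q_new = tuple(qi + eps * pi for qi, pi in zip(q, p_half))
    g_new = grad_potential(q_new)
    p_new = tuple(pi - 0.5 * eps * gi for pi, gi in zip(p_half, g_new))
    energy_change = potential(q_new) + kinetic(p_new) - energy_before
    return q_new, p_new, energy_change, abs(energy_change) > E_MAX


# A source point in the narrow neck of the funnel, with a step size that is too large there.
source = (-4.0, 0.1)
dest, _, d_energy, diverged = leapfrog_step(source, (0.0, 1.0), eps=0.5)
print("source:", source)
print("destination:", dest)
print("energy change:", d_energy, "diverged:", diverged)
```

In this sketch the step starting in the neck overshoots badly, so the energy change far exceeds the threshold and the step is flagged as divergent. The (source, destination) pair is exactly the kind of detail that makes a divergence diagnosable.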
As far as I know, this kind of divergent leapfrog detail is not captured in the corresponding InferenceData object. Or at least I cannot find it:
Is there a plan to capture this super-useful divergent detail in InferenceData?