InferenceData is missing details about divergent transitions?

pm.sample() can optionally return a trace as an ArviZ InferenceData object instead of a MultiTrace, and soon that will be the default behavior. InferenceData is better than MultiTrace in several ways, but it seems to be missing a key feature.

When NUTS sampling encounters divergent transitions, details about the divergences are captured in the MultiTrace object. For example, consider the Devil’s Funnel (ht: @AlexAndorra):

import pymc3 as pm  # assumes PyMC3 3.x, where pm.sample() returns a MultiTrace by default

with pm.Model() as m:
    v = pm.Normal("v", 0.0, 3.0)
    x = pm.Normal("x", 0.0, pm.math.exp(v))

    trace = pm.sample(random_seed=19950526)

This model produces many divergent transitions. We can see the details of these transitions in the MultiTrace object by examining trace.report._chain_warnings:

Note that the warnings capture the leapfrog source and destination, that is, the pair of points whose overly large energy change triggered the divergence. This information is useful for determining why the divergence occurred. Of course, for the Devil’s Funnel it’s obvious why the divergences occurred, but sometimes it is not so obvious. Here is a plot from another (more complex) model that shows the leapfrog sources and destinations of the divergent transitions:

Without the detail to draw this plot, it is difficult to diagnose these divergences.
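
For anyone who wants to reproduce that kind of plot, here is a minimal sketch of collecting the source/destination pairs from the report. It assumes that trace.report._chain_warnings is a mapping from chain index to a list of warning records, and that each divergence record exposes divergence_point_source and divergence_point_dest attributes. Both are assumptions about a private API (based on PyMC3 3.11) and may differ in other versions. collect_divergence_pairs is a hypothetical helper name.

```python
def collect_divergence_pairs(chain_warnings):
    """Gather (chain, source, destination) for every divergence warning.

    `chain_warnings` is assumed to map chain index -> list of warning
    records, e.g. trace.report._chain_warnings (private API).
    """
    pairs = []
    for chain, warnings in chain_warnings.items():
        for warning in warnings:
            # Attribute names are assumptions; records without them are skipped.
            source = getattr(warning, "divergence_point_source", None)
            dest = getattr(warning, "divergence_point_dest", None)
            if source is not None and dest is not None:
                pairs.append((chain, source, dest))
    return pairs


# Hypothetical usage with a MultiTrace from pm.sample(return_inferencedata=False):
# pairs = collect_divergence_pairs(trace.report._chain_warnings)
# for chain, source, dest in pairs:
#     print(chain, source, dest)
```

Using getattr with a default means non-divergence warnings (or records from a version without these fields) are skipped rather than raising an error.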

As far as I know, this kind of divergent leapfrog detail is not captured in the corresponding InferenceData object. Or at least I cannot find it:

Is there a plan to capture this super-useful divergent detail in InferenceData?
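
For context, what InferenceData does record is a per-draw boolean flag, idata.sample_stats["diverging"], with dimensions (chain, draw). That tells you which draws diverged, but not the leapfrog source and destination. A minimal sketch of listing the flagged draws follows. divergent_draw_indices is a hypothetical helper, written in plain Python so that it works on nested lists as well as on the array returned by .values.

```python
def divergent_draw_indices(diverging):
    """Return (chain, draw) index pairs whose divergence flag is set.

    `diverging` is a 2-D boolean structure indexed as [chain][draw].
    """
    return [
        (chain, draw)
        for chain, flags in enumerate(diverging)
        for draw, flag in enumerate(flags)
        if flag
    ]


# Hypothetical usage with an InferenceData object named `idata`:
# indices = divergent_draw_indices(idata.sample_stats["diverging"].values)
# v_at_divergences = [idata.posterior["v"].values[c, d] for c, d in indices]
```

With NumPy available, np.argwhere on the same array gives the equivalent index pairs.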

I second this concern. I switched from return_inferencedata=False to return_inferencedata=True so I can do more analyses on the returned InferenceData object, but since the output of pm.sample then has no report attribute, I can’t check for convergence issues explicitly. Is there a way to retrieve report when using return_inferencedata=True? This is an unfortunate omission.

My temporary (and not-so-elegant) workaround is to copy pm.sampling.sample into a “utilities” module in my project, rename the function to pm_sample, and replace its last lines with:

    if return_inferencedata:
        return idata, trace
    else:
        return None, trace

So in my code, I call:

inf_data, trace = utils.pm_sample(..., return_inferencedata=True, ...)
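
A lighter alternative to copying pm.sampling.sample might be to sample with return_inferencedata=False, keep the MultiTrace for its report, and convert it afterwards. The sketch below assumes PyMC3 3.11 or later, where pm.to_inference_data is available. sample_with_trace is a hypothetical helper name, and there is no guarantee that the converted object matches what return_inferencedata=True would produce in every case.

```python
def sample_with_trace(model, **sample_kwargs):
    """Sample once and return both an InferenceData and the MultiTrace."""
    # Imported here so the helper can be defined without PyMC3 installed.
    import pymc3 as pm

    with model:
        # Force a MultiTrace so that trace.report stays available.
        trace = pm.sample(return_inferencedata=False, **sample_kwargs)
        # Assumes pm.to_inference_data exists (PyMC3 3.11 or later).
        idata = pm.to_inference_data(trace, model=model)
    return idata, trace


# Hypothetical usage, mirroring the call above:
# inf_data, trace = sample_with_trace(m, random_seed=19950526)
# print(trace.report._chain_warnings)
```

This avoids maintaining a private copy of the sampling function, at the cost of doing the conversion as a separate step.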

Can one of you open a feature request issue on our GitHub?

This would be part of the first goal in “v4 <-> InferenceData integration” (Issue #5160, pymc-devs/pymc on GitHub). You can also leave a comment there instead of opening a new issue. But if you do open a new one, please link to this one.

Thanks! Added a comment there.
