How to use stored divergences in nutpie.sample()?

My rather complex model sees a 5X speedup in sampling using nutpie instead of out-of-the-box NUTS. Thank you @aseyboldt . And thank you @cluhmann for alerting me to nutpie, by informing me of @AlexAndorra’s most excellent podcast.

Unfortunately my model diverges on some samples, about 1-2% of them. nutpie.sample() takes an optional store_divergences argument, which would seem to apply. The resultant trace has a sample_stats field with four new attributes: divergence_start, divergence_start_gradient, divergence_end, and divergence_momentum, and values for each of the divergent samples across the 344 RVs. I understand what these attributes mean, but I can’t figure out how to use them to locate the complex geometry in the model. How do I narrow down which combination of RVs is creating the problem?

Perhaps I am missing something obvious. It would not be the first time.

For very complex models I find that divergences are sometimes just what you have to live with. Step one for me is to check the traces and make sure they look nice, there’s good mixing between chains, and the posteriors are mostly normal-ish without any horrible lumps.
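
Something like this is usually enough for that first pass, assuming idata is the InferenceData you got back from sampling:

import arviz as az

# Trace plots show mixing and the shape of each marginal posterior.
az.plot_trace(idata)

# Rank plots make between-chain mixing problems easier to spot than trace plots.
az.plot_rank(idata)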

If there are problems, the next step is to go to az.plot_pair and look at pairwise scatterplots between posterior variables. You can set divergences=True to see which samples diverged and which are fine. The goal here is (1) to look for local identification issues (any plots that reveal a deterministic relationship between two variables), and (2) to look for patterns in the divergences (e.g., concentrated above or below a certain value, or in a particular region of parameter space).
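
For example, a minimal call, with made-up variable names standing in for whatever parameters you suspect:

import arviz as az

# Pairwise scatterplots with the divergent draws highlighted.
az.plot_pair(
    idata,
    var_names=["mu", "sigma"],  # placeholders: pick a handful of suspect parameters
    divergences=True,
    kind="scatter",
)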

Another thing I like to do is plot the priors against the posteriors and check what is being informed by the data and what isn’t. When there are priors that are not informed by the data at all, it can be important to go back and set more informative priors (especially if the parameter has a scientific interpretation).
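
A rough sketch of how I do that with PyMC and ArviZ, assuming the model object is model and the posterior is already in idata (the variable name in var_names is just a placeholder):

import pymc as pm
import arviz as az

# Add prior samples to the existing InferenceData.
with model:
    idata.extend(pm.sample_prior_predictive())

# Overlay prior and posterior for the variables you care about.
az.plot_dist_comparison(idata, var_names=["sigma"])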

I’ll confess I’ve never used the more sophisticated diagnostics returned by nutpie, so I’ll leave commenting on those specifically to Adrian.


For very big models, it can be tough to know which parameters are suspect. So a lot of diagnostics aren’t terribly useful until you’ve narrowed the problem down to a group of parameters. I find two things work okay:

  1. Sort variables by ESS. The low-ESS variables are typically the ones with the problems. This tends to work if at least some of the parameters are independent of each other. If two parameters are heavily correlated, then bad ESS on one will bring the other down with it. (See the sketch after this list.)

  2. Sort variables by high variance in the tails of the mass matrix (see the sample-stats page in the nutpie docs). The mass matrix will settle into a stationary distribution during tuning. The parameters that are hard to sample from often won’t.
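
For the ESS sorting in point 1, something along these lines works, assuming idata is your InferenceData:

import arviz as az

# Bulk ESS per variable; vectors/matrices are summarized by their worst element.
ess = az.ess(idata)
worst_first = sorted(
    ((float(ess[var].min()), var) for var in ess.data_vars),
    key=lambda pair: pair[0],
)

# The variables at the top of this list are the ones to look at first.
for value, var in worst_first[:10]:
    print(f"{var}: min bulk ESS = {value:.0f}")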


This is what we actually do, but I don’t know that anyone’s ever written that down in any of our documentation for Stan!

This is a good heuristic, but you have to watch out for cases like hierarchical scale parameters, which typically mix slowly compared to lower-level parameters, despite being closely linked in value by construction. But then these might not have high Pearson correlation since that’s only getting at a linear relationship.

This is what we hope will happen, but if that doesn’t happen, it might be more helpful to look at the variance of the draws. Looking at variances can be tricky because their scales are often tied to the scale of means (imagine changing a regression covariate’s units from millimeters to meters, for example—the errors in the estimated coefficients will scale inversely with the means).

Here’s a little simulation to illustrate @daniel-saunders-phil’s point.

The Python program generates a deterministic Markov chain x from a sinusoid over just a few cycles, so it has very low ESS. Then it generates a second chain y centered around the values in x with normal noise of scales 0.1, 1, and 10. You’ll see that as the noise goes up, the linear correlation goes down and the ESS decouples, just as @daniel-saunders-phil said it would.

This is just a simulated version of what happens when fitting a (centered) hierarchical model. The means of the two chains are essentially identical because the second chain is just a noisy version of the first, but the correlations vary depending on how much noise there is in the second. When the variance in y is large compared to the variance in x, you see the ESS get out of line and the correlation go down, even though the two chains are linked in exactly the same way, just at different scales.

import numpy as np
import arviz as az
import xarray as xr

def generate_series(n_points, cycles, noise_sd, seed=123):
    # Deterministic sinusoid x plus a noisy copy y.
    rng = np.random.default_rng(seed)
    t = np.arange(n_points) / n_points
    x = np.sin(2.0 * np.pi * cycles * t)
    y = x + rng.normal(loc=0.0, scale=noise_sd, size=n_points)
    return x, y

def build_idata(x, y):
    # Wrap the two series as a single-chain InferenceData posterior.
    x_chain = x[None, :]
    y_chain = y[None, :]
    ds = xr.Dataset(
        data_vars={"x": (("chain", "draw"), x_chain), "y": (("chain", "draw"), y_chain)},
        coords={"chain": [0], "draw": np.arange(x.shape[0])},
    )
    return az.InferenceData(posterior=ds)

def main():
    n_points = 10000
    cycles = 1.5
    # Use a tuple (not a set) so the noise levels run in a fixed order.
    for noise_sd in (0.1, 1, 10):
        x, y = generate_series(n_points=n_points, cycles=cycles, noise_sd=noise_sd, seed=2025)
        print(f"NOISE SD: {noise_sd}\n")
        print(f"CORRELATION: {np.corrcoef(x, y)[0, 1]:0.2f}\n")
        idata = build_idata(x, y)
        summary_df = az.summary(idata)
        print(summary_df)

main()

And here’s the output:

$ python sin-corr.py
NOISE SD: 0.1

CORRELATION: 0.99

arviz - WARNING - Shape validation failed: input_shape: (1, 10000), minimum_shape: (chains=2, draws=4)
    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
x  0.212  0.675  -0.960     1.00       0.27    0.092       7.0      40.0    NaN
y  0.213  0.681  -1.013     1.08       0.27    0.092       7.0      45.0    NaN
NOISE SD: 1

CORRELATION: 0.55

arviz - WARNING - Shape validation failed: input_shape: (1, 10000), minimum_shape: (chains=2, draws=4)
    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
x  0.212  0.675   -0.96    1.000      0.270    0.092       7.0      40.0    NaN
y  0.216  1.207   -2.09    2.455      0.262    0.040      22.0     119.0    NaN
NOISE SD: 10

CORRELATION: 0.06

arviz - WARNING - Shape validation failed: input_shape: (1, 10000), minimum_shape: (chains=2, draws=4)
    mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
x  0.212   0.675  -0.960    1.000      0.270    0.092       7.0      40.0    NaN
y  0.252  10.072 -18.596   18.947      0.106    0.073    9097.0    9021.0    NaN

P.S. Thanks to GPT-5 for help with ArviZ—it’s a game changer for me as I can never remember the incantation in that build_idata code. I did review and tighten the final code up a bit and added the correlation output.


These suggestions are all valuable. But I remain interested in how the extra attributes collected by nutpie can be used: divergence_start, divergence_start_gradient, divergence_end and divergence_momentum.

Perhaps a batsignal is needed to alert @aseyboldt ?


Not sure why the batsignal worked, but it seems it did. :)

I don’t really know of anything useful you could do with divergence_momentum, unless you are debugging issues in the sampler itself.

A divergence really happens during a leapfrog step. divergence_start is the location in (unconstrained) parameter space where a diverging leapfrog step started, and divergence_end is the location where it ended.

divergence_end can be useful if the problem is not actually posterior curvature but a numerical issue of some kind. If you have, for instance, underflows or overflows and end up with infs or nans in the logp or gradient, divergence_end will contain the location in parameter space where that happened.

divergence_start can be more useful for debugging the source of curvature-related divergences than using the diverging draw itself, because the diverging draw is really a draw from the trajectory that contained a divergence, and doesn’t have to be close to the divergence.

Here is an example of how to access this on the sample-stats page of the nutpie docs.
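
A minimal sketch of pulling those locations out, assuming divergence_start is stored in trace.sample_stats with chain, draw, and unconstrained-parameter dimensions, and NaN rows for draws that did not diverge (the exact dimension names may differ in your trace):

import numpy as np

div_start = trace.sample_stats["divergence_start"]

# Flatten chains and draws, then keep only the rows that actually diverged.
values = div_start.values.reshape(-1, div_start.shape[-1])
div_starts = values[~np.isnan(values).any(axis=1)]

print(f"{len(div_starts)} divergences recorded")
# Compare where divergences start with the bulk of the posterior, per parameter.
print("median divergence start:", np.median(div_starts, axis=0))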

If you use those to find the problems that lead to the divergences, I would recommend setting max_energy_error to something much smaller than the default of 1000. Maybe 10 or so. That will typically give you a much larger number of divergences, and the divergences will also be closer to the problematic region in the parameter space.
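
Something like this, assuming the usual workflow of compiling the model first (model is a placeholder for your PyMC model) and that the sampler settings can be passed straight to nutpie.sample, as store_divergences is in the original post:

import nutpie

compiled = nutpie.compile_pymc_model(model)

# A much lower max_energy_error flags divergences earlier, closer to the
# problematic region, and store_divergences keeps the diagnostics discussed above.
trace = nutpie.sample(
    compiled,
    max_energy_error=10,
    store_divergences=True,
)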


It can be tricky to use in a high dimensional model though. I often just look at a couple of pairwise scatter plots if I have suspicions about certain parameters.

If you want to be fancy, you can also train a classifier (from sklearn for instance) to distinguish divergence locations from draws. The feature weights then often point to interesting parameters.
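
A small sketch of that idea with scikit-learn, assuming you already have draws (regular posterior draws on the unconstrained space, shape (n_draws, n_params)) and div_starts (the divergence_start rows extracted as in the snippet above):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Label regular draws 0 and divergence start locations 1.
X = np.vstack([draws, div_starts])
y = np.concatenate([np.zeros(len(draws)), np.ones(len(div_starts))])

# Standardize so the coefficients are comparable across parameters.
X = StandardScaler().fit_transform(X)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Parameters with the largest absolute weights separate divergences from
# ordinary draws best -- good candidates to inspect with pair plots.
ranking = np.argsort(-np.abs(clf.coef_[0]))
print(ranking[:10])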
