Diagnosing divergences when clipping is involved

I have a hierarchical model of homogeneous Poisson processes with many groups (3000+) of event sequences, which I assume come from the same population. I model the group Poisson rates using a simple population hierarchy with a non-centred parametrisation:

import numpy as np
import pymc as pm

coords = {"obs": data.index.values, "group": group_coords}

with pm.Model(coords_mutable=coords) as model:
    time_increments = pm.MutableData("time_increments", data.increment.values/365, dims="obs")
    groups = pm.MutableData("groups", data.group_ix, dims="obs")
    observed_counts = data.counts.values
    
    pop_mean = pm.Normal("pop_mean", mu=10, sigma=4)
    pop_sigma = pm.Exponential("pop_sigma", lam=1)
    
    group_zscores = pm.Normal("group_z", mu=0, sigma=1, dims="group")
    group_means = pm.Deterministic("group_means", pop_mean + pop_sigma * group_zscores, dims="group")
    
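    # hard clip: any negative group mean is truncated to a Poisson rate of exactly 0,
    # and the gradient with respect to group_means is exactly zero below the boundary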
    avg_increment_lambdas = pm.math.clip(group_means[groups], 0, float("inf"))
    total_increment_lambdas = pm.Deterministic("total_increment_lambda", avg_increment_lambdas * time_increments, dims="obs")
    
    counts = pm.Poisson("counts", mu=total_increment_lambdas, observed=observed_counts, dims="obs")

with model:
    idata = pm.sample(chains=4, keep_warning_stat=True)

This produces many divergences:

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2579 seconds.
There were 118 divergences after tuning. Increase `target_accept` or reparameterize.

I was aware that I had a potential problem with the zero-clipping, but I concluded it wasn’t an issue because, as far as I could tell, no samples were actually clipped (all group_means draws are positive):

idata.posterior["group_means"].min().values
array(0.10084822)

So I spent a lot of time trying other things (changing target_accept, switching to a centred parametrisation, even digging cluelessly through the detailed divergence reports returned with keep_warning_stat=True), none of which worked. Eventually, for lack of anything else to try, I decided to get rid of the clipping:

with pm.Model(coords_mutable=coords) as model2:
    time_increments = pm.MutableData("time_increments", data.increment.values/365, dims="obs")
    groups = pm.MutableData("groups", data.group_ix, dims="obs")
    observed_counts = data.counts.values
    
    log_pop_mean = pm.Normal("pop_mean", mu=1.5, sigma=1)
    log_pop_sigma = pm.Exponential("pop_sigma", lam=10)
    
    group_means = pm.LogNormal("group_means", mu=log_pop_mean, sigma=log_pop_sigma, dims="group")
    
    avg_increment_lambdas = group_means[groups]
    total_increment_lambdas = pm.Deterministic("total_increment_lambda", avg_increment_lambdas * time_increments, dims="obs")
    
    counts = pm.Poisson("counts", mu=total_increment_lambdas, observed=observed_counts, dims="obs")

with model2:
    idata2 = pm.sample(chains=4, keep_warning_stat=True)

And lo and behold!

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1460 seconds.

So my questions are:

  1. Am I right that the clipping was at fault? (As opposed to some other change I had to make in order to get rid of the clipping, such as setting slightly different priors for the population mean and variance.) EDIT: I am aware that now I technically have a centred parametrisation with the LogNormal. I had tried a centred parametrisation with clipping and that still gave me divergences.
  2. What, in general, is the main issue with clipping? Is it that it’s non-differentiable? Or that it forces you to sample near the boundary (in the same way as MCMC sampling of HalfNormal causes divergences)?
  3. Why did clipping cause issues in this particular case, considering that no samples were actually clipped?
  4. Was there any way to systematically diagnose that? A plot I could have done, some summary sampling statistic I could have looked at? Some anomaly I could have looked for in the sampling reports?

Divergent samples don’t get accepted, so you won’t see them. Clipping makes the log density flat (zero gradient) below zero, which NUTS can’t handle. The best option is to soft-clip, using something like tfp.bijectors.SoftClip from TensorFlow Probability.
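
For context, a one-sided soft clip at zero can be sketched in PyMC with a scaled softplus (this illustrates the idea rather than the tfp bijector itself; the function name and softness value below are made up):

def soft_clip_lower(x, lower=0.0, softness=0.1):
    # lower + softness * softplus((x - lower) / softness):
    # approaches max(x, lower) as softness -> 0, but the gradient is never exactly flat
    return lower + softness * pm.math.log1pexp((x - lower) / softness)

# e.g. replacing the hard clip in the model above:
# avg_increment_lambdas = soft_clip_lower(group_means[groups])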

We should actually add it to PyMC

Hm, if they are discarded that would make sense, thank you! So when there is a diverging trajectory, the actual diverging sample is discarded and a new non-diverging one is drawn from earlier in the trajectory? But it’s still marked as diverging?

If so would it be correct to conclude that:

  1. The actual discarded divergent samples are available in the sampling reports?
  2. The divergent samples in the sampling reports should be negative?

I don’t really understand the structure of the sampling report (is it documented anywhere?) but here is what I attempted:

idata_warns = idata.sample_stats.where(idata.sample_stats["warning"] != None, drop=True)
warnings = idata_warns["warning"].sel(warning_dim_0=0).stack(sample=["chain", "draw"]).dropna(dim="sample")

divergence_point_source = [w.item().divergence_point_source for w in warnings]
divergence_point_dest = [w.item().divergence_point_dest for w in warnings]

This should get the details of the raw (transformed) variables as seen by the sampler. I then transform these values to obtain the group_means (which are the quantities that would actually be clipped):

div_group_means_source = [d['pop_mean'] + np.exp(d['pop_sigma_log__']) * d['group_z'] for d in divergence_point_source]
div_group_means_dest = [d['pop_mean'] + np.exp(d['pop_sigma_log__']) * d['group_z'] for d in divergence_point_dest]

Now if I understood it correctly, my expectation is that all entries in div_group_means_source should be positive and all entries in div_group_means_dest should be negative:

(np.array(div_group_means_source) < 0).any()
True

So far so good…

(np.array(div_group_means_dest) < 0).any(axis=1).sum()/len(div_group_means_dest)
0.8217054263565892

Ok, 82% of the divergences contain a negative group_means entry that would have needed to be clipped. So that’s mostly consistent with my understanding above. But 82% is not 100%. Why would there be a divergence if there are no negative group_means in that draw? Isn’t clip identical to the identity function in that case?

To get into the weeds a little, HMC (and NUTS) simulates some dynamics where energy is conserved. The integrator is very good, but can also be (arbitrarily) bad. Luckily, you can just check how much the energy has changed from the initial conditions to tell how badly the integrator has done. We then do a Metropolis-Hastings correction using this energy change (where the probability of acceptance is exp(-energy_change)) so that it is valid MCMC (i.e., it converges to the stationary distribution).

A “divergence” is just when the energy_change is more than 1,000. That indicates that the integrator did very, very poorly (note that exp(-1000) == 0 in float64, so the transition has no chance of being accepted). But a divergence is not defined mathematically here; the threshold is a heuristic.

Anyways, yeah, I think that probably if you start at a point and the sampled momentum sends your group_means negative, then the gradients might get weird and give you weird numbers, but hard to tell without checking.
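
For reference, the relevant quantities are recorded per draw in idata.sample_stats (the stat names below assume a recent PyMC NUTS sampler), so you can check how close the non-divergent draws came to that threshold:

import numpy as np

divergent = idata.sample_stats["diverging"]
max_err = idata.sample_stats["max_energy_error"]  # worst energy error in each trajectory

print("divergence rate:", float(divergent.mean()))
# energy errors of the non-divergent draws; values creeping towards the
# threshold (1000 by default) suggest near-divergences
print(np.quantile(max_err.values[~divergent.values], [0.5, 0.9, 0.99]))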


That’s very helpful, thank you @colcarroll!

I guess my main remaining concern/question is your point that it’s “hard to tell without checking”. I understand that you don’t know enough about the model to say anything more concrete, but is there any “checking” that I can do on my side? I get that divergences are notoriously hard to debug (and, as you explained, flagged based on somewhat arbitrary rules), but I am still hoping to develop some best practices/intuitions for tackling them.

There are a few interesting questions here that you might be asking:

  1. How can I sample from my model?

So, the best choice is to reparametrize your model, which is (unfortunately!) context dependent and maybe the hardest option. One common thing to do is to non-center your model: if you have something like

mu = pm.Something(...)
std = pm.SomethingElse(...)

latent = pm.Normal("latent", mu, std)

You might change that to

mu = pm.Something(...)
std = pm.SomethingElse(...)

latent_unscaled = pm.Normal("latent_unscaled", 0, 1)
latent = pm.Deterministic("latent", mu + std * latent_unscaled)

(similarly for other location-scale parameters).

Another thing to look for is “heavy tail” parameters, like a Cauchy or a StudentT with small degrees of freedom – can you replace those with Gaussians? Maybe you can’t, because then it is not The One True Model, and you’ll need to roll your own sampler!

  2. How can I tell why I’m getting divergences?

When you have a divergence, you just get a rejected sample. So you can see what parameters cause a rejected sample. I think @aseyboldt might have hooked up some tools to also see the momentum draws that lead to a divergence. Note that the trajectory is deterministic once you know the momentum, so you could (in theory) recreate that trajectory and look at the gradients/log densities at each point. That would require a lot of code, and you’d probably just be like “boy, those gradients sure changed fast! I bet there was some intense curvature in this part of parameter space!”, then go back to step 1. But maybe it would be more useful! And maybe you’d share that code back, which would be neat!
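
For what it’s worth, a cheaper version of that check is to evaluate the log density and its gradient at the recorded source and destination points of each divergence (a sketch, reusing the divergence_point_source/divergence_point_dest lists extracted earlier and assuming a PyMC v5-style model.compile_logp / model.compile_dlogp API):

import numpy as np

# compiled over the unconstrained (transformed) parameters, which is the
# space the point dicts from the warning objects live in
logp_fn = model.compile_logp()
dlogp_fn = model.compile_dlogp()  # flattened gradient vector

for src, dst in zip(divergence_point_source[:10], divergence_point_dest[:10]):
    print(
        f"logp {float(logp_fn(src)):12.1f} -> {float(logp_fn(dst)):12.1f}   "
        f"|grad| {np.linalg.norm(dlogp_fn(src)):12.1f} -> {np.linalg.norm(dlogp_fn(dst)):12.1f}"
    )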

This is great, thank you. So to summarise:

  1. Re-parametrise (this is obviously the most well-documented thing to try and somewhat well understood in the case of Gaussians; less clear once you have non-Gaussians, non-real supports, etc.)
  2. Look out for distributions with a lot of mass near the boundary (e.g. HalfNormal; see here)
  3. Look out for transformations with regions of flat gradients (e.g. clipping)
  4. Look out for heavy-tailed distributions (e.g. Cauchy, StudentT; this one was sort of new to me since many tutorials rely on them for “robust regression”)
  5. Write custom samplers (the nuclear button approach, I take it…)
  6. In most cases avoid deep dives into the source of individual divergences (though, in this specific case I found it informative that I could see negative values in the discarded samples whereas there were no negative values in the accepted samples)

One thing that hasn’t come up is looking at posterior pair plots. This is my first step when diagnosing divergences. I’m looking for 1) high pairwise correlations (plots that look like straight lines), and 2) clusters of divergences in the parameter space. If a cluster is at a boundary (like a truncation point), this is an obvious problem. If it’s somewhere else, it might indicate something higher dimensional.
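
With ArviZ that looks roughly like the following (the variable names are the ones from the model above; with 3000+ groups you would restrict the plot to a handful of groups at a time, assuming group_coords can be sliced):

import arviz as az

# overlay the divergent draws on the pairwise posterior scatter plots
az.plot_pair(
    idata,
    var_names=["pop_mean", "pop_sigma", "group_z"],
    coords={"group": group_coords[:3]},  # only a few groups at a time
    divergences=True,
)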

I recommend this article by Michael Betancourt about diagnosing pathologies for a deep dive into this issue. Not much in the way of how to resolve it, but knowing how problems come about is still extremely helpful.

One thing that hasn’t come up is looking at posterior pair plots. This is my first step when diagnosing divergences.

That’s an excellent point and in fact the original reason for this thread. Indeed, the pair plots were my first go-to, but since I have 3000+ groups I couldn’t figure out how to actually inspect them. Going through 3000+ pair plots wouldn’t have been feasible, and I hoped that there was a way of somehow “sorting” the groups by how problematic they are, or otherwise drawing my attention to a hopefully smaller subset of the 3000+.

In the end that was unnecessary once I realised that the clipping was at fault, but the original problem still stands, especially since Betancourt advises that sometimes you might want some of your groups parametrised centrally and others non-centrally. If one were in that situation with 3000+ groups, it would be a challenge.

In cases where I’ve had a lot of variables, I started by computing all the pairwise correlation coefficients then plotting the largest ones (in absolute value). Maybe not the most elegant solution, but it was the first thing that came to my mind at the time.
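
For example (a sketch using az.extract from a recent ArviZ version; the variable names are the ones from the model above):

import arviz as az
import numpy as np

# flatten the posterior into one (samples, parameters) matrix
post = az.extract(idata, var_names=["pop_mean", "pop_sigma", "group_z"])
names = ["pop_mean", "pop_sigma"] + [
    f"group_z[{g}]" for g in post["group_z"]["group"].values
]
mat = np.column_stack(
    [post["pop_mean"].values, post["pop_sigma"].values, post["group_z"].values.T]
)

# strongest pairwise posterior correlations, largest absolute value first
corr = np.corrcoef(mat, rowvar=False)
iu = np.triu_indices_from(corr, k=1)
for k in np.argsort(-np.abs(corr[iu]))[:10]:
    i, j = iu[0][k], iu[1][k]
    print(f"{names[i]} ~ {names[j]}: {corr[i, j]:+.2f}")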

That’s a great suggestion! It’s certainly helpful with highly correlated variables; I guess it’s less likely to identify funnels, but it’s still a great start!

Good point about funnels; I’m not sure how to start tackling that. I wonder if you could train a classifier to look at a set of draws for a given pair of parameters and identify funnel candidates? @aseyboldt did some work on identifying degenerate parameters using gradient information as features in a classifier here. Not sure if it was ever fleshed out any further.


Also Find divergence sources using feature importance · Issue #217 · pymc-devs/pymc-experimental · GitHub


Great suggestions, thank you!