In a big model that I’m working with, I started to run into 100% divergences when I added some observations. It took me a long while to identify the problem, and in the process I came across something that I found very confusing and, at least in my opinion, worth sharing here.
Consider this simple model, which I sample directly with nutpie:
import numpy as np
import nutpie
import pymc as pm
from matplotlib import pyplot as plt
n1 = 150
n2 = 7
coords = {
    "dim1": range(n1),
    "dim2": range(n2),
}
observed = np.random.normal(size=(n1, n2))
observed[:, 0] = 0  # every observation in the first dim2 column is exactly 0
with pm.Model(coords=coords) as model:
    sigma = pm.HalfNormal("sigma", dims=["dim2"])
    obs = pm.Normal(
        "obs", mu=0, sigma=sigma, dims=["dim1", "dim2"], observed=observed
    )
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(
    compiled,
    tune=500,
    draws=100,
    chains=4,
    seed=1234,
    return_raw_trace=True,
    store_mass_matrix=True,
)
This model has a single free random variable, sigma, with shape dim2. All the observations for the first entry of dim2 are set to 0, and doing this causes 100% divergent samples.
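For the first entry of dim2 it may help to write out the part of the log density that the sampler sees for that component. This is a sketch under stated assumptions: PyMC's default HalfNormal scale of 1, the log transform implied by the sigma_log__ parameter names, n = n1 = 150 observations that are all exactly 0, and u as my own shorthand for sigma_log___0 (so sigma[0] = e^u, with log-Jacobian u):

```latex
\begin{aligned}
\log p(u \mid y_{\cdot 0} = 0)
  &= \sum_{i=1}^{n} \log \mathcal{N}\!\left(0 \mid 0, e^{u}\right)
   + \log \mathrm{HalfNormal}\!\left(e^{u} \mid 1\right)
   + u \\
  &= \left(-n u - \tfrac{n}{2}\log 2\pi\right)
   + \left(\tfrac{1}{2}\log\tfrac{2}{\pi} - \tfrac{1}{2}e^{2u}\right)
   + u \\
  &= (1 - n)\,u - \tfrac{1}{2}e^{2u} + \text{const},
  \qquad
  \frac{d}{du}\log p = (1 - n) - e^{2u}.
\end{aligned}
```

With n = 150 the leading term is -149u, which grows without bound as u goes to minus infinity, so this factor has no finite normalizing constant, and its gradient approaches the constant -149 in that direction.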
With nutpie’s raw trace, I can grab a bunch of sample stats like the inverse mass matrix or the gradients.
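# One dict-like row of sample stats per draw
t1 = trace[0][1].to_pandas()
keys = t1[0].keys()
stats = {key: [] for key in keys}
for row in t1:
    for key, val in row.items():
        stats[key].append(val)
# Stack each statistic into an array with draws along the first axis
for key, val in stats.items():
    stats[key] = np.stack(val)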
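The two loops above turn the per-draw rows into one stacked array per statistic. The same pattern can be packaged as a small function. This is only a sketch: rows_to_arrays and the fake rows are hypothetical, purely for illustration, and it assumes every row has the same keys.

```python
import numpy as np


def rows_to_arrays(rows):
    """Turn an iterable of dicts (one per draw) into a dict of stacked arrays.

    Hypothetical helper: assumes every row has the same keys.
    """
    rows = list(rows)
    keys = rows[0].keys()
    # np.stack puts the draw index on the first axis
    return {key: np.stack([row[key] for row in rows]) for key in keys}


# Fake rows shaped like per-draw stats: one scalar and one length-7 vector
fake_rows = [
    {"scalar_stat": 0.5 / (i + 1), "mass_matrix_inv": np.full(7, float(i))}
    for i in range(3)
]
stats = rows_to_arrays(fake_rows)
print(stats["mass_matrix_inv"].shape)  # (3, 7)
```

Since iterating a pandas Series yields its values, t1 could be passed in directly, e.g. rows_to_arrays(t1).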
And if you plot the inverse mass matrix like this:
plt.semilogy(stats["mass_matrix_inv"])
plt.ylabel("Inverse mass matrix diag")
plt.xlabel("draw")
you get:
[Figure: inverse mass matrix diagonal vs. draw, log scale]
The step sizes all tend to 0:
[Figure: step size vs. draw]
Just to confirm: the inverse mass matrix entry that is orders of magnitude bigger than the rest belongs to the sigma whose observations are all exactly 0. The unconstrained parameters are ordered like this:
> compiled._coords["unconstrained_parameter"]
Index(['sigma_log___0', 'sigma_log___1', 'sigma_log___2', 'sigma_log___3',
'sigma_log___4', 'sigma_log___5', 'sigma_log___6'],
dtype='object')
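To see what the sampler faces along sigma_log___0, the part of the log density that depends on u = log(sigma[0]) can be evaluated directly. A NumPy sketch, assuming PyMC's default HalfNormal(sigma=1) prior, the log transform, and n1 = 150 observations that are all exactly 0; the function names are hypothetical helpers, not PyMC or nutpie API:

```python
import numpy as np

# Assumptions mirroring the model above: HalfNormal(sigma=1) prior,
# log-transformed sigma, and N_OBS observations that are all exactly 0.
N_OBS = 150


def logp_sigma_log_zero_column(u, n=N_OBS):
    """Log density of u = log(sigma), up to nothing: includes all constants."""
    u = np.asarray(u, dtype=float)
    sigma = np.exp(u)
    # Sum of log N(0 | 0, sigma) over n points; log(sigma) is written as u
    loglik = n * (-0.5 * np.log(2 * np.pi) - u)
    logprior = 0.5 * np.log(2 / np.pi) - 0.5 * sigma**2  # log HalfNormal(sigma | 1)
    log_jacobian = u  # d sigma / d u = exp(u)
    return loglik + logprior + log_jacobian


def grad_logp_sigma_log_zero_column(u, n=N_OBS):
    """Derivative of the log density above with respect to u."""
    u = np.asarray(u, dtype=float)
    return (1 - n) - np.exp(2 * u)


us = np.array([0.0, -5.0, -10.0, -20.0, -50.0])
print(logp_sigma_log_zero_column(us))       # keeps increasing as u decreases
print(grad_logp_sigma_log_zero_column(us))  # approaches 1 - n = -149
```

Because the derivative (1 - n) - exp(2u) is negative everywhere, the log density keeps rising as u decreases, and for very negative u the gradient is essentially the constant 1 - n.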
I can see why a huge inverse mass matrix diagonal entry causes 100% divergences: even a small momentum produces a huge displacement, just like pushing a very light particle around, so the trajectory blows up. What I don't understand is why the mass matrix adaptation drives this entry up for the HalfNormal in this case. It would be awesome if anyone here had some insight into the cause.
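The light-particle intuition can be checked with a single leapfrog step under a diagonal inverse mass matrix, where the position update is x <- x + step_size * M^{-1} p. A minimal NumPy sketch: the standard normal target, step size, and momentum are made up for illustration, and this is the textbook leapfrog integrator, not nutpie's implementation.

```python
import numpy as np


def grad_logp_std_normal(x):
    """Gradient of log N(x | 0, I)."""
    return -x


def hamiltonian_std_normal(x, p, inv_mass_diag):
    """Potential (-log p up to a constant) plus kinetic energy 0.5 * p^T M^{-1} p."""
    return 0.5 * np.sum(x**2) + 0.5 * np.sum(inv_mass_diag * p**2)


def leapfrog_step(x, p, grad_logp, step_size, inv_mass_diag):
    """One leapfrog step with a diagonal inverse mass matrix."""
    p_half = p + 0.5 * step_size * grad_logp(x)
    x_new = x + step_size * inv_mass_diag * p_half  # velocity is M^{-1} p
    p_new = p_half + 0.5 * step_size * grad_logp(x_new)
    return x_new, p_new


def one_step(inv_mass_diag, step_size=0.1):
    """Take one step from the origin with unit momentum; return position and energy error."""
    x0 = np.zeros(2)
    p0 = np.ones(2)  # identical momentum in both coordinates
    x1, p1 = leapfrog_step(x0, p0, grad_logp_std_normal, step_size, inv_mass_diag)
    energy_error = hamiltonian_std_normal(x1, p1, inv_mass_diag) - hamiltonian_std_normal(
        x0, p0, inv_mass_diag
    )
    return x1, energy_error


x_ok, err_ok = one_step(np.array([1.0, 1.0]))
x_bad, err_bad = one_step(np.array([1.0, 1e6]))
print(x_ok, err_ok)    # both coordinates move 0.1; energy error is about 2.5e-5
print(x_bad, err_bad)  # second coordinate jumps to 1e5; energy error is around 1.25e13
```

With the same momentum and step size, the coordinate with the huge inverse mass entry moves a million times farther, and the energy error is large enough that a sampler would flag the transition as divergent.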
— Edit —
I just realized that I can make the same plots from an InferenceData object instead, like this:
trace = nutpie.sample(
    compiled,
    tune=500,
    draws=100,
    chains=4,
    seed=1234,
    return_raw_trace=False,
    store_mass_matrix=True,
    save_warmup=True,
)
trace.warmup_sample_stats.mass_matrix_inv.isel(chain=0).plot(
    x="draw", hue="unconstrained_parameter"
)
which produces this:
[Figure: warmup mass_matrix_inv of chain 0 vs. draw, one line per unconstrained parameter]