Gamma model sampling much worse when observation summaries are used instead of individual observations

I'm working with a simple hierarchical gamma model:

import aesara.tensor as at
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

Here is how we generate data:

rng = np.random.default_rng(42)

# Hyperparameters
p_true = 6.
q_true = 4.
γ_true = 15.

# Number of subjects
N = 500  
# Subject level parameters
ν_true = pm.draw(pm.Gamma.dist(q_true, γ_true, size=N), random_seed=rng)

# Number of observations per subject
x = rng.poisson(lam=2, size=N) + 1  
idx = np.repeat(np.arange(0, N), x)
# Observations
z = pm.draw(pm.Gamma.dist(p_true, ν_true[idx]), random_seed=rng)
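
A quick sanity check that the ragged indexing lines up (a self-contained sketch of the same setup, not part of the model):

```python
import numpy as np

rng = np.random.default_rng(42)
N = 500

# Same construction as above: each subject i contributes x[i] observations
x = rng.poisson(lam=2, size=N) + 1
idx = np.repeat(np.arange(N), x)

# Subject i appears exactly x[i] times in idx, and len(idx) is the
# total number of observations
assert np.array_equal(np.bincount(idx), x)
assert len(idx) == x.sum()
```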

The following model on the individual observations samples alright (ignore the bad priors unless they explain everything…):

with pm.Model() as m1:
    p = pm.HalfFlat("p")
    q = pm.HalfFlat("q")
    γ = pm.HalfFlat("γ")
    
    ν = pm.Gamma("ν", q, γ, size=N)
    pm.Gamma("z", p, ν[idx], observed=z)
    
    trace1 = pm.sample(random_seed=rng)

az.summary(trace1, var_names=["p", "q", "γ"])

     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p   6.038  0.264   5.559    6.528      0.023    0.016     132.0     289.0   1.02
q   3.939  0.304   3.400    4.521      0.009    0.007    1029.0    1253.0   1.00
γ  15.022  1.514  12.183   17.800      0.092    0.065     274.0     630.0   1.01

Now we compute a summary of the individual observations:

df = pd.DataFrame(data={"z": z, "id": idx})
z_sum = df.groupby("id").sum()["z"].values
z_mean = df.groupby("id").mean()["z"].values
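
(As a sanity check, the same per-subject sums can be computed with `np.bincount`; this is just an alternative aggregation on a small hypothetical example, not part of the model:)

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical small example mirroring the structure above
N = 5
x = rng.poisson(lam=2, size=N) + 1
idx = np.repeat(np.arange(N), x)
z = rng.gamma(shape=6.0, scale=1.0, size=x.sum())

df = pd.DataFrame(data={"z": z, "id": idx})
z_sum = df.groupby("id").sum()["z"].values

# np.bincount with weights computes the same per-subject sums
assert np.allclose(np.bincount(idx, weights=z), z_sum)
```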

And taking advantage of some Gamma properties write the following (correct?) model:

with pm.Model() as m2:
    p = pm.HalfFlat("p")
    q = pm.HalfFlat("q")
    γ = pm.HalfFlat("γ")

    ν = pm.Gamma("ν", q, γ, size=N)
    pm.Gamma("z_sum", p*x, ν, observed=z_sum)
    # pm.Gamma("z_mean", p*x, ν*x, observed=z_mean)  # Equally bad
    
    trace2 = pm.sample(random_seed=rng)

az.summary(trace2, var_names=["p", "q", "γ"])

     mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p  36.292  32.486   2.928   83.056     22.676   19.106       2.0      11.0   2.34
q   3.634   0.622   2.743    4.763      0.385    0.310       3.0      33.0   1.84
γ  11.024  11.638   0.620   32.566      7.409    6.025       3.0      12.0   2.30

But it samples very poorly… why?


I think this is an identification problem. The pair plot of your second model is quite suggestive:

If I’m understanding correctly, you’re trying to exploit the summation property of gammas here: given independent X_i \sim Gamma(\alpha_i, \beta) for i = 1, \dots, N, the sum satisfies \sum_{i=1}^N X_i \sim Gamma\left(\sum_{i=1}^N \alpha_i, \beta\right). So if \alpha_i = \alpha for all i, then \sum_{i=1}^N X_i \sim Gamma(N\alpha, \beta)?
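
That property is easy to check numerically; a quick NumPy sketch (the specific values are just illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, beta, k = 6.0, 15.0, 4  # shape, rate, number of summands

# Sum of k iid Gamma(alpha, beta) draws (NumPy parameterizes by scale = 1/rate)
sums = rng.gamma(shape=alpha, scale=1 / beta, size=(200_000, k)).sum(axis=1)

# Should match Gamma(k * alpha, beta): mean = k*alpha/beta, var = k*alpha/beta**2
print(sums.mean(), k * alpha / beta)    # both ≈ 1.6
print(sums.var(), k * alpha / beta**2)  # both ≈ 0.107
```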

I think the problem is that when you try to leverage that, you end up estimating 500 rate parameters from 500 observations, which gives the identification problem suggested in the pair plot.

Indeed, if you pin down the value of p to the true value, you recover the other two parameters:

with pm.Model() as m3:
    p = 6
    q = pm.HalfFlat("q")
    γ = pm.HalfFlat("γ")

    ν = pm.Gamma("ν", q, γ, size=N)
    pm.Gamma("z_sum", p*x, ν, observed=z_sum)
    # pm.Gamma("z_mean", p*x, ν*x, observed=z_mean)  # Equally bad
    
    trace3 = pm.sample(random_seed=rng)

     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
q   3.916  0.292   3.367    4.472      0.006    0.004    2525.0    2532.0    1.0
γ  14.991  1.238  12.598   17.271      0.025    0.018    2465.0    2568.0    1.0

Thank you so much for the analysis. Yes that was the logic for the second model.

I guess my (lack of) intuition for the gamma parameters betrayed me. So the summary statistics are not sufficient to condition the individual parameters, probably because the dispersion information is gone.

I wonder if there are ranges of observation counts/expected observation values that are less ill-conditioned, or if the gamma just stays too flexible to infer one parameter per individual.

It’s curious that the summary statistics are (extremely) locally sufficient, though. When p is set to precisely the true value, everything works totally fine. But I also tried setting p \sim N(6, 1), and even this caused the model to fail. I would have expected a tight prior around the true value to at least give something, but it didn’t.

I struggle with the notion of identification in Bayesian models. I am often stuck in a frequentist “you can’t have more parameters than observations” mindset, but this is not always the case, nor is it even the case here: the model estimates 502 parameters perfectly from 500 data points, given the true value of p. Things are evidently much more subtle in MCMClandia than in the People’s Republic of Matrix Inversion.

The functional form of the dependence between p and \gamma in the pair plot is also tantalizing. I wonder if one could do some algebra on the model to discover precisely what is going on here. A mystery for another day, I suppose.
