Working with a simple hierarchical gamma model

```
import aesara.tensor as at
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
```

Here is how we generate data:

```
rng = np.random.default_rng(42)
# Hyperparameters
p_true = 6.
q_true = 4.
γ_true = 15.
# Number of subjects
N = 500
# Subject level parameters
ν_true = pm.draw(pm.Gamma.dist(q_true, γ_true, size=N), random_seed=rng)
# Number of observations per subject
x = rng.poisson(lam=2, size=N) + 1
idx = np.repeat(np.arange(0, N), x)
# Observations
z = pm.draw(pm.Gamma.dist(p_true, ν_true[idx]), random_seed=rng)
```
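To make the role of `idx` concrete, here is a small sketch with hypothetical toy sizes (`x_toy`, `idx_toy` and `nu_toy` are illustration-only names, not part of the model above). It shows how `np.repeat` expands per-subject counts into one subject index per observation, and how indexing a per-subject array with that index gives one value per observation:

```python
import numpy as np

# Hypothetical toy sizes: 3 subjects with 2, 1 and 3 observations.
x_toy = np.array([2, 1, 3])
idx_toy = np.repeat(np.arange(len(x_toy)), x_toy)
print(idx_toy)  # [0 0 1 2 2 2]

# Indexing a per-subject array with idx_toy yields one entry per observation.
nu_toy = np.array([0.2, 0.3, 0.4])
print(nu_toy[idx_toy])  # [0.2 0.2 0.3 0.4 0.4 0.4]
```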

The following model, fit to the individual observations, samples all right (ignore the bad priors, unless they explain everything…):

```
with pm.Model() as m1:
    p = pm.HalfFlat("p")
    q = pm.HalfFlat("q")
    γ = pm.HalfFlat("γ")
    ν = pm.Gamma("ν", q, γ, size=N)
    pm.Gamma("z", p, ν[idx], observed=z)
    trace1 = pm.sample(random_seed=rng)
az.summary(trace1, var_names=["p", "q", "γ"])
```

```
     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p   6.038  0.264   5.559    6.528      0.023    0.016     132.0     289.0   1.02
q   3.939  0.304   3.400    4.521      0.009    0.007    1029.0    1253.0   1.00
γ  15.022  1.514  12.183   17.800      0.092    0.065     274.0     630.0   1.01
```

Now we compute per-subject summaries (sum and mean) of the individual observations:

```
df = pd.DataFrame(data={"z": z, "id": idx})
z_sum = df.groupby("id").sum()["z"].values
z_mean = df.groupby("id").mean()["z"].values
```
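The summarized model only makes sense if entry `i` of the summary lines up with entry `i` of `x`. With `groupby("id")` that should hold, since groups are sorted by id and every subject has at least one observation (`x >= 1`). A pandas-free cross-check is sketched below using `np.bincount` on hypothetical toy data (all `_toy` names are made up for illustration):

```python
import numpy as np

rng_toy = np.random.default_rng(0)

# Toy data with the same structure as above: counts, subject index, observations.
x_toy = rng_toy.poisson(lam=2, size=5) + 1
idx_toy = np.repeat(np.arange(len(x_toy)), x_toy)
z_toy = rng_toy.gamma(shape=6.0, scale=1.0, size=len(idx_toy))

# Per-subject counts, sums and means, ordered by subject id.
counts = np.bincount(idx_toy, minlength=len(x_toy))
sums = np.bincount(idx_toy, weights=z_toy, minlength=len(x_toy))
means = sums / counts

# The counts recovered from idx_toy should match x_toy exactly.
assert np.array_equal(counts, x_toy)
```

On the real data, the same idea would be comparing `np.bincount(idx, weights=z)` against `z_sum`.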

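Then, taking advantage of a Gamma property (the sum of n i.i.d. Gamma(p, ν) variables with rate ν is Gamma(n·p, ν), and their mean is Gamma(n·p, n·ν)), we write the following (correct?) model: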

```
with pm.Model() as m2:
    p = pm.HalfFlat("p")
    q = pm.HalfFlat("q")
    γ = pm.HalfFlat("γ")
    ν = pm.Gamma("ν", q, γ, size=N)
    pm.Gamma("z_sum", p*x, ν, observed=z_sum)
    # pm.Gamma("z_mean", p*x, ν*x, observed=z_mean) # Equally bad
    trace2 = pm.sample(random_seed=rng)
az.summary(trace2, var_names=["p", "q", "γ"])
```

```
     mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p  36.292  32.486   2.928   83.056     22.676   19.106       2.0      11.0   2.34
q   3.634   0.622   2.743    4.763      0.385    0.310       3.0      33.0   1.84
γ  11.024  11.638   0.620   32.566      7.409    6.025       3.0      12.0   2.30
```

But it samples very poorly (r_hat around 2 and ess_bulk of 2 to 3 for all three hyperparameters)… why?
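Separately from the sampling problem, the Gamma identity that `m2` relies on can be checked by simulation. The sketch below assumes the shape/rate parametrization (which is how the second positional argument of `pm.Gamma` is read here) and uses hypothetical values `p_chk`, `nu_chk` and `n_chk`. It compares quantiles of a sum of `n_chk` Gamma draws against direct draws from Gamma(`n_chk * p_chk`, `nu_chk`):

```python
import numpy as np

rng_check = np.random.default_rng(0)

# Hypothetical values: shape, rate, and number of observations for one subject.
p_chk, nu_chk, n_chk = 6.0, 0.3, 3
draws = 200_000

# NumPy parametrizes the gamma by shape and scale, where scale = 1 / rate.
parts = rng_check.gamma(shape=p_chk, scale=1.0 / nu_chk, size=(draws, n_chk))
summed = parts.sum(axis=1)
direct = rng_check.gamma(shape=n_chk * p_chk, scale=1.0 / nu_chk, size=draws)

# If the identity holds, both samples should have matching quantiles.
probs = [0.05, 0.25, 0.5, 0.75, 0.95]
q_summed = np.quantile(summed, probs)
q_direct = np.quantile(direct, probs)
print(np.round(q_summed, 2))
print(np.round(q_direct, 2))
```

Matching quantiles only support the distributional identity for a fixed `ν`; they say nothing about why the sampler struggles with `m2`.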