You could for instance rewrite the regression like this:
```python
coords = {"group": group_list}
with pm.Model(coords=coords) as hierarchical:
    x = pm.MutableData("x", data.x, dims="obs_id")
    group_idx = pm.MutableData("group_idx", data.group_idx, dims="obs_id")

    intercept = pm.Normal("intercept", sigma=1)

    group_sigma = pm.HalfNormal("group_sigma", sigma=2)
    group_offset = pm.Normal("group_offset", dims="group")
    group_effect = pm.Deterministic("group_effect", group_sigma * group_offset, dims="group")

    x_effect = pm.Normal("x_effect")

    x_group_sigma = pm.HalfNormal("x_group_sigma", sigma=2)
    x_group_offset = pm.Normal("x_group_offset", dims="group")
    x_group_effect = pm.Deterministic("x_group_effect", x_group_sigma * x_group_offset, dims="group")

    mu = (
        intercept
        + group_effect[group_idx]
        + x_effect * x
        + x_group_effect[group_idx] * x
    )

    sigma = pm.HalfNormal("sigma", sigma=2)
    pm.Normal("y", mu=mu, sigma=sigma, observed=data.y, dims="obs_id")
```
```python
def generate():
    # <--- I changed the groups
    group_list = [f"group_{i}" for i in range(500)]
    trials_per_group = 20

    group_intercepts = rng.normal(0, 1, len(group_list))
    # <--- I changed the group slopes to not be all exactly identical
    group_slopes = rng.normal(-0.5, 0.1, size=len(group_list))
    group_mx = group_intercepts * 2

    group = np.repeat(group_list, trials_per_group)
    subject = np.concatenate(
        [np.ones(trials_per_group) * i for i in np.arange(len(group_list))]
    ).astype(int)
    intercept = np.repeat(group_intercepts, trials_per_group)
    slope = np.repeat(group_slopes, trials_per_group)
    mx = np.repeat(group_mx, trials_per_group)

    x = rng.normal(mx, 1)
    y = rng.normal(intercept + (x - mx) * slope, 1)

    data = pd.DataFrame({"group": group, "group_idx": subject, "x": x, "y": y})
    return data, group_list
```
I changed the data generation slightly so that not all groups have exactly the same slope, and so that we have 500 groups. The model is for the most part a reparametrization; I did remove the different sigma values per group though. That wasn't in the data generation anyway, and seemed a bit strange to me.
The default sampler still struggles a little bit with this (somewhat low ESS), but nutpie for instance seems to be perfectly fine, and samples in ~5s. The lowest ESS is ~400:
```python
import arviz
import nutpie

compiled = nutpie.compile_pymc_model(hierarchical)
tr = nutpie.sample(compiled)

ess = arviz.ess(tr)
ess.min()
```
Some more changes that could improve things further:
- Use a ZeroSumNormal where appropriate
- Normalize the predictor x
Depending on the dataset, it is also possible that a centered parametrization for some of the parameters improves things (when datasets get bigger, it is not unusual that the centered version works better).
Edit: I'm also a bit confused by the bias right now. I'll have another look tomorrow.