I suggest you try to start with the simplest case and expand gradually. That should help you pinpoint the problem faster.
Here is a simplified example:
import pymc as pm
# Fit a small hierarchical model on two known groups.
# `idx` maps each of the six observations to its group (position along "id").
idx = [0, 0, 0, 1, 1, 1]
coords = {
    "id": [0, 1],
    "obs_idx": list(range(len(idx))),
}
with pm.Model(coords=coords) as m:
    x = pm.Data("x", [0, 0, 0, 0, 0, 0], dims=("obs_idx",))
    # Shared parent of the per-group effects.
    mu_group = pm.Normal("mu_group")
    # One effect per entry of the "id" coordinate, centred on mu_group.
    b = pm.Normal("b", mu_group, dims=("id",))
    # Select each observation's group effect through its group index.
    mu = b[idx] * x
    obs = pm.Normal("obs", mu, observed=[0, 1, 2, 3, 4, 5], dims=("obs_idx",))
    # pm.sample() is called without a model argument, so it has to run
    # inside this `with` block to pick up `m` from the context.
    idata = pm.sample()
# Predict two new groups
# `idx` now maps three new observations onto the two new groups
# (positions along the new "id" coordinate).
idx = [0, 0, 1]
new_coords = {
    "id": [2, 3],
    "obs_idx": list(range(len(idx))),
}
with pm.Model(coords=new_coords) as pred_m:
    x = pm.Data("x", [0, 0, 0], dims=("obs_idx",))
    # Same name as in the fitted model, so it can be matched against the
    # draws already stored in `idata`.
    mu_group = pm.Normal("mu_group")
    # Needs a new name, so old b is not used!
    new_b = pm.Normal("new_b", mu_group, dims=("id",))
    mu = new_b[idx] * x
    # No `observed=` here: "obs" is the quantity being predicted.
    obs = pm.Normal("obs", mu, dims=("obs_idx",))
    # Called without `model=`, so it has to run inside this `with` block to
    # use `pred_m`; `idata` is the result of the earlier pm.sample() call.
    idata = pm.sample_posterior_predictive(
        idata,
        var_names=["obs"],
        predictions=True,
        extend_inferencedata=True,
    )