I see it now. I need to include the “clusters” dimension in my $\alpha$ (intercept), but to make it work with $\mu$, I need to calculate its mean. It may still need some tweaking, but this is essentially what I was hoping to achieve:
# Toy sample of ten observations: three predictors (BMI, WTGAIN, M_Ht_In),
# an integer cluster label (CLUSTERS) and the response (DBWT).
# NOTE(review): WTGAIN holds the value 99 three times; if 99 is a
# missing-value code in the source data it should be masked before
# modelling -- confirm.
example_data = pd.DataFrame(
    data={
        column: np.array(values)
        for column, values in (
            ('BMI', [13.43666, 16.061496, 22.998563, 14.499094, 18.248859,
                     13.811637, 15.061559, 19.873758, 15.436535, 13.186676]),
            ('WTGAIN', [20, 39, 99, 31, 47, 99, 99, 1, 34, 29]),
            ('M_Ht_In', [12, 12.8, 12.2, 12.2, 12.8, 13.8, 12.4, 13, 12.8, 13.4]),
            ('CLUSTERS', [3, 3, 2, 3, 3, 2, 2, 3, 3, 3]),
            ('DBWT', [2610, 3190, 3232, 2410, 2780, 3033, 2495, 3518, 3381, 2693]),
        )
    }
)

# Encode the cluster labels as 0-based integer codes (one per row);
# `cluster_name[k]` is the original CLUSTERS label that code k stands for.
cluster_index, cluster_name = example_data['CLUSTERS'].factorize()
# Hierarchical (partially pooled) linear regression of DBWT on BMI, WTGAIN and
# M_Ht_In: each cluster gets its own intercept and its own slope vector, and
# both are drawn from shared hyper-priors. Relies on `example_data`,
# `cluster_index` and `cluster_name` defined above.
with pm.Model() as Hier:
    # --- coordinates -------------------------------------------------------
    Hier.add_coord("num_cols", ['BMI', 'WTGAIN', 'M_Ht_In'])
    Hier.add_coord("obs", example_data.index)
    # Label clusters from the factorize output so that label k always names
    # integer code k in `cluster_index`, instead of hard-coding the order.
    Hier.add_coord("clusters", [f"cluster_{c}" for c in cluster_name])

    # --- data containers ---------------------------------------------------
    X = pm.MutableData("X", example_data[['BMI', 'WTGAIN', 'M_Ht_In']].values)
    y = pm.MutableData("y", example_data['DBWT'].values)

    # --- hyper-priors shared across clusters -------------------------------
    # NOTE(review): DBWT in this sample ranges 2410-3518, while every prior
    # below has scale 3, so the intercepts can barely reach the data. Consider
    # standardising X and y (or widening these priors) before trusting fits.
    αμ = pm.Normal("αμ", mu=0., sigma=3.)
    ασ = pm.HalfNormal("ασ", sigma=3.)
    βμ = pm.Normal("βμ", mu=0., sigma=3.)
    βσ = pm.HalfNormal("βσ", sigma=3.)

    # Observation noise. It is used as a standard deviation, so it needs a
    # positive-support prior; a plain Normal would propose negative sigmas.
    ϵ = pm.HalfNormal("ϵ", sigma=3.)

    # --- per-cluster parameters, drawn from the hyper-priors ---------------
    α = pm.Normal("α", mu=αμ, sigma=ασ, dims="clusters")
    β = pm.Normal("β", mu=βμ, sigma=βσ, dims=("clusters", "num_cols"))

    # --- linear predictor --------------------------------------------------
    # Row i of X is paired with the slopes and intercept of observation i's
    # cluster: β[cluster_index] and α[cluster_index] each have one entry per
    # observation, so μ has one entry per observation too. X itself must NOT
    # be indexed by cluster_index (that would repeat rows 0/1 only), and α
    # must not be averaged (that would collapse the varying intercepts into a
    # single shared one).
    μ = (X * β[cluster_index]).sum(axis=1) + α[cluster_index]

    # --- likelihood --------------------------------------------------------
    yhat = pm.Normal("yhat", mu=μ, sigma=ϵ, observed=y, dims="obs")

    # --- sampling ----------------------------------------------------------
    # NOTE(review): tune=50 is a very short warm-up; expect convergence /
    # divergence warnings and consider raising it.
    trace = pm.sample(
        250,
        tune=50,
        chains=4,
        return_inferencedata=True,
        idata_kwargs={'log_likelihood': True},
    )