Hi, I can’t figure out how to make inference about latent parameters on new data in a hierarchical model.
First I simulated some data:
MU0 = 100
SIGMA0 = 10
MU1 = 5
NUM_USERS = 18
NUM_TRAIN_USERS = 10
MEAN_NUM_OBS = 5
def expand_list_of_list(lol):
return [x for y in lol for x in y]
user_mean = pm.draw(pm.Normal.dist(MU0, SIGMA0), draws=NUM_USERS, random_seed=1)
user_sd = pm.draw(pm.HalfNormal.dist(MU1), draws=NUM_USERS, random_seed=1)
num_obs = pm.draw(pm.Poisson.dist(MEAN_NUM_OBS), draws=NUM_USERS, random_seed=1) + 1
sim_pd = (
pd.DataFrame({
'user_id': expand_list_of_list([[i] * x for i, x in enumerate(num_obs)]),
'mean': expand_list_of_list([[x] * y for x, y in zip(user_mean, num_obs)]),
'sigma': expand_list_of_list([[x] * y for x, y in zip(user_sd, num_obs)]),
'value': expand_list_of_list([pm.draw(pm.Normal.dist(x, y), draws=z, random_seed=1) for x, y, z in zip(user_mean, user_sd, num_obs)]),
})
)
train_pd = sim_pd.query(f"user_id < {NUM_TRAIN_USERS}")
test_pd = sim_pd.query(f"user_id >= {NUM_TRAIN_USERS}").reset_index(drop=True)
Then I constructed a hierarchical model like this using the training data:
with pm.Model(coords={'user_id': sorted(list(train_pd['user_id'].unique())), 'user_id_list': train_pd['user_id']}) as m:
user_mu = pm.Normal('user_mu', MU0, dims='user_id')
user_sigma = pm.HalfNormal('user_sigma', SIGMA0, dims='user_id')
value = pm.Data('value', train_pd['value'], dims='user_id_list')
y = pm.Normal('y', mu=user_mu[train_pd['user_id']], sigma=user_sigma[train_pd['user_id']], observed=value, dims='user_id_list')
idata = pm.sample()
That looked okay, but if I try to make inference on the new data like this:
with m:
pm.set_data(
{'value': test_pd['value']},
coords={'user_id': sorted(list(test_pd['user_id'].unique())), 'user_id_list': test_pd['user_id']}
)
idata.extend(pm.sample_posterior_predictive(idata))
I’d get this error:
ValueError: shape mismatch: objects cannot be broadcast to a single shape. Mismatch is between arg 0 with shape (45,) and arg 1 with shape (65,).
Can I get some help on how I can fix this? Thank you.