In my opinion, this is one of the places where `pm.Data`, xarray and combinations of both shine.

To begin with, you can use named dimensions to indicate that the shapes of `intercept` and `x_coeff` are not both 10 by coincidence but because they represent the same dimension. Moreover, you can wrap the data in `pm.Data` containers labeled with their own `obs_id` dimension, so they can be replaced later:
```python
import numpy as np
import pymc as pm

# `x`, `training_sample` and `function` come from the original question;
# `function` stands for whatever link function your model uses
with pm.Model(
    # note you can use the actual group names (if they exist)
    # and use those to subset the data afterwards
    coords={"group": np.arange(10), "obs_id": np.arange(len(x))}
) as hierarchical_model:
    indx = pm.Data("indx", training_sample["indx"], dims="obs_id")
    x_data = pm.Data("x_data", x, dims="obs_id")
    # Hyperpriors for group nodes
    mu_intercept = pm.Normal("mu_intercept", mu=0.0, sigma=20)
    sigma_intercept = pm.HalfNormal("sigma_intercept", sigma=10)
    mu_x = pm.Normal("mu_x", mu=0.0, sigma=20)
    sigma_x = pm.HalfNormal("sigma_x", sigma=10)
    # Define priors, one value per group
    intercept = pm.Normal("Intercept", mu=mu_intercept, sigma=sigma_intercept, dims="group")
    x_coeff = pm.Normal("x_coeff", mu=mu_x, sigma=sigma_x, dims="group")
    # Index the group-level parameters with each observation's group
    mu = pm.Deterministic(
        "mu", function(intercept[indx] + x_coeff[indx] * x_data), dims="obs_id"
    )
    ...
```
Note that the `obs_id` dims are only informative: the shape is set by the input arrays, and the dims only label it.
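To see why the labels matter, here is a small sketch using plain xarray (not PyMC) with made-up values: two length-10 arrays that share the `group` dim are combined elementwise, while arrays with different dim names are broadcast against each other. The dims given to PyMC variables become the dims of the posterior in the returned InferenceData, which is what makes the alignment further down automatic.

```python
import numpy as np
import xarray as xr

# Made-up values; only the dim names matter here
groups = np.arange(10)
intercept = xr.DataArray(np.arange(10.0), dims="group", coords={"group": groups})
x_coeff = xr.DataArray(np.ones(10), dims="group", coords={"group": groups})
x_obs = xr.DataArray(np.linspace(0, 1, 10), dims="obs_id")

# Same dim name: combined elementwise, the result keeps a single "group" dim
same = intercept + x_coeff
# Different dim names: broadcast to a 2D (group, obs_id) array
outer = intercept + x_coeff * x_obs

print(same.dims, same.shape)
print(outer.dims, outer.shape)
```

Had both arrays been plain NumPy arrays of length 10, nothing would distinguish the two cases; the names carry that information.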
After sampling, you'll already have `mu` in the posterior. (`mu` is a deterministic computation from the posterior variables, so it's not really a posterior predictive variable, which is why, as you have probably noticed, you need to specify it explicitly for `pm.sample_posterior_predictive` to sample it.) The problem, as you have noticed, is that it has the `obs_id` dimension instead of the `group` one. The group-level parameters behind it have only 10 unique values, one per group, but indexing with `indx` repeats them to match the group of each observation.
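A small NumPy sketch of what the indexing in the model does, using hypothetical values for a single posterior draw with 3 groups: `intercept[indx]` repeats each group's value once per observation in that group, and the per-group values can be read back from the first observation of each group.

```python
import numpy as np

# Hypothetical single posterior draw for 3 groups
intercept = np.array([1.0, 2.0, 3.0])
x_coeff = np.array([0.5, -1.0, 2.0])

# Group index and predictor for 6 observations
indx = np.array([0, 0, 1, 2, 2, 2])
x_data = np.array([0.0, 1.0, 2.0, 0.0, 1.0, 2.0])

# Indexing expands group-level values to the obs_id dimension
intercept_obs = intercept[indx]  # values 1, 1, 2, 3, 3, 3
mu = intercept_obs + x_coeff[indx] * x_data

# Recover one value per group from the repeated obs-level array
groups, first_obs = np.unique(indx, return_index=True)
intercept_per_group = intercept_obs[first_obs]
```

Note that `mu` itself still varies within a group because `x_data` does; only the group parameters are repeated.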
If you want the group effects at a given x (say x=2.34), you can use `pm.sample_posterior_predictive`:
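```python
with hierarchical_model:
    # one "observation" per group, all at x = 2.34
    pm.set_data(
        {"x_data": np.full(10, 2.34), "indx": np.arange(10)},
        # recent PyMC versions need new coords when the obs_id length changes
        coords={"obs_id": np.arange(10)},
    )
    pp = pm.sample_posterior_predictive(trace_hierarchical, var_names=["mu"])
```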
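With `indx` set to `0..9`, `intercept[indx]` is just `intercept`, so the same quantity can also be computed directly from the posterior draws. A NumPy sketch with random stand-in draws (hypothetical sizes: 2 chains, 50 draws, 10 groups) in place of a real trace:

```python
import numpy as np

rng = np.random.default_rng(0)
n_chain, n_draw, n_group = 2, 50, 10

# Stand-ins for posterior["Intercept"] and posterior["x_coeff"], dims (chain, draw, group)
intercept = rng.normal(size=(n_chain, n_draw, n_group))
x_coeff = rng.normal(size=(n_chain, n_draw, n_group))

# Data as it would be passed to pm.set_data
indx = np.arange(n_group)
x_data = np.full(n_group, 2.34)

# Model formula, indexing the last (group) axis of every draw
mu_indexed = intercept[..., indx] + x_coeff[..., indx] * x_data
# Direct computation: one value per group at x = 2.34
mu_direct = intercept + x_coeff * 2.34
```

Here the `obs_id` positions line up one-to-one with the groups, which is what makes the `set_data` trick work.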
But you probably want multiple x values per group so you can plot group-level HDIs and so on. One option is to do something like the above but the other way around, in a sense: loop over groups and use `pm.set_data({"x_data": np.linspace(...), "indx": [group]*len(linspace)})`. We can change the length of the `obs_id` dimension, but we can't change the number of dimensions of the variables `x_data` and `indx`; they are 1d arrays.
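One possible alternative to the loop (a swapped-in technique, not from the thread) is to flatten the whole (group, x) grid into a single long `obs_id` dimension: only its length changes, and `x_data` and `indx` stay 1d. A NumPy sketch with a hypothetical grid of 20 x values per group and random stand-in draws; the two long arrays are what you would pass to `pm.set_data`, and the `(chain, draw, obs_id)` result is reshaped back afterwards.

```python
import numpy as np

n_group, n_x = 10, 20
x_grid = np.linspace(0.0, 5.0, n_x)  # hypothetical plotting range

# Long 1d arrays: group 0 for every grid point, then group 1, and so on
indx_long = np.repeat(np.arange(n_group), n_x)
x_long = np.tile(x_grid, n_group)  # grid repeated once per group

# Stand-ins for posterior draws, dims (chain, draw, group)
rng = np.random.default_rng(1)
intercept = rng.normal(size=(2, 30, n_group))
x_coeff = rng.normal(size=(2, 30, n_group))
mu_long = intercept[..., indx_long] + x_coeff[..., indx_long] * x_long

# Group varies slowest in obs_id, so a reshape recovers (group, x) axes
mu_grid = mu_long.reshape(2, 30, n_group, n_x)
```

The reshape is only valid because `np.repeat` and `np.tile` put the group on the slow axis and x on the fast one; mixing them up would silently scramble the grid.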
Another option is to use xarray, which handles most of the broadcasting and alignment of arrays:
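```python
import arviz as az
import numpy as np
import xarray as xr

post = trace_hierarchical.posterior
# plotting grid; the range is an assumption, adjust it to your data
x_da = xr.DataArray(np.linspace(x.min(), x.max(), 100), dims=["x_plot"])
# linear predictor; apply the same `function` as in the model if needed
mu = post["Intercept"] + post["x_coeff"] * x_da
# mu is a 4d array with dims chain, draw, group, x_plot
# you can select and plot a group hdi with (here the first group)
az.plot_hdi(x_da, mu.sel(group=post.coords["group"].values[0]))
```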
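For comparison, this is roughly the broadcasting xarray does for you, written with plain NumPy and random stand-in draws: the axes have to be aligned by hand with `None`, which is exactly the bookkeeping named dims remove. The band below uses central quantiles as a simple stand-in, not a true HDI (ArviZ's `az.hdi` computes the highest-density interval).

```python
import numpy as np

rng = np.random.default_rng(2)
n_chain, n_draw, n_group, n_x = 4, 100, 10, 25

# Stand-ins for posterior["Intercept"] and posterior["x_coeff"], dims (chain, draw, group)
intercept = rng.normal(size=(n_chain, n_draw, n_group))
x_coeff = rng.normal(size=(n_chain, n_draw, n_group))
x_plot = np.linspace(0.0, 5.0, n_x)  # hypothetical grid

# Add a trailing x_plot axis so the posterior arrays broadcast against x_plot
mu = intercept[..., None] + x_coeff[..., None] * x_plot
# mu has axes (chain, draw, group, x_plot)

# Pool chains and draws, then take a central 94% band for one group
group = 3
samples = mu[:, :, group, :].reshape(-1, n_x)
lower, upper = np.quantile(samples, [0.03, 0.97], axis=0)
```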
More details and working examples for a related question are at 11.16 Rethinking Code - #8 by OriolAbril. (I haven't run any of the code above and it won't run as is; some parts are pseudocode.)