Need help plotting HDI from posterior_predictive when Y_obs has two dimensions

My posterior_predictive group has the following dimensions: chain: 4, draw: 15000, Y_obs_dim_2: 3, Y_obs_dim_3: 250.

This Y_obs is the solution of a system of ordinary differential equations: Y_obs_dim_2 indexes the species, and Y_obs_dim_3 indexes the time points at which each species' concentration is recorded.

I would like to plot the HDI along Y_obs_dim_3 for each species (that is, for each value of Y_obs_dim_2). Basically, I want three plots, one per species, each showing Y_obs and its HDI (similar to the plot_hdi example plots).

After computing the HDI, I get the following:

```
Dimensions:      (Y_obs_dim_2: 3, Y_obs_dim_3: 249, hdi: 2)
Coordinates:
  * Y_obs_dim_2  (Y_obs_dim_2) int64 0 1 2
  * Y_obs_dim_3  (Y_obs_dim_3) int64 0 1 2 3 4 5 6 ... 243 244 245 246 247 248
  * hdi          (hdi) <U6 'lower' 'higher'
Data variables:
    Y_obs        (Y_obs_dim_2, Y_obs_dim_3, hdi) float64 0.9429 1.043 ... 0.359
```

How do I use plot_hdi for each value of Y_obs_dim_2? Or, if I can’t use plot_hdi, how can I access the result for each index of Y_obs_dim_2?
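The HDI result is an ordinary xarray Dataset, so the bounds for one species can be pulled out with label-based selection on Y_obs_dim_2 and on the hdi coordinate ('lower'/'higher'). A minimal sketch, assuming a synthetic stand-in for the posterior_predictive group (random values with the same dim names as the question, and 200 draws instead of 15000 to keep it quick); `pp`, `bands`, and the sizes are illustrative only:

```python
import numpy as np
import xarray as xr
import arviz as az

# Synthetic stand-in for the posterior_predictive group: random values with
# dims (chain, draw, Y_obs_dim_2, Y_obs_dim_3). Only the layout matters here.
rng = np.random.default_rng(0)
n_chains, n_draws, n_species, n_times = 4, 200, 3, 250
pp = xr.Dataset(
    {"Y_obs": (("chain", "draw", "Y_obs_dim_2", "Y_obs_dim_3"),
               rng.normal(size=(n_chains, n_draws, n_species, n_times)))},
    coords={"chain": np.arange(n_chains), "draw": np.arange(n_draws),
            "Y_obs_dim_2": np.arange(n_species), "Y_obs_dim_3": np.arange(n_times)},
)

# HDI over chain and draw; Y_obs now has dims (Y_obs_dim_2, Y_obs_dim_3, hdi)
hdi = az.hdi(pp)

# Bounds for a single species (index 0), each a 1-D array over time
species_0 = hdi["Y_obs"].sel(Y_obs_dim_2=0)
lower = species_0.sel(hdi="lower").values
upper = species_0.sel(hdi="higher").values

# Loop over every species; each band is a (Y_obs_dim_3, hdi) DataArray
bands = {}
for k in hdi["Y_obs"].coords["Y_obs_dim_2"].values:
    bands[int(k)] = hdi["Y_obs"].sel(Y_obs_dim_2=k)
```

With real data, passing `idata.posterior_predictive` to `az.hdi` instead of `pp` should give the same layout as the printout above.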

You could loop over the coordinates in Y_obs_dim_2 and make the plots “by hand” with matplotlib; that is roughly how I like to plot HDIs. If you can, I recommend rewriting your model to use named dimensions (for example species and time instead of Y_obs_dim_2 and Y_obs_dim_3), which makes the selections much easier to read. I made a small fake dataset that I think looks like yours, with dims animal and time, and sampled the posterior and posterior predictive. Here’s how I plot the posterior predictive mean and HDI for each animal:

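```python
import arviz as az
import matplotlib.pyplot as plt

# Assumes `idata` has a posterior predictive variable `y_hat` with dims
# (chain, draw, animal, time) and `df` holds the observed data, one column per animal.
animals = idata.posterior_predictive.coords['animal'].values
times = idata.posterior_predictive.coords['time'].values
posterior_mu = az.extract(idata, group='posterior_predictive').y_hat.mean(dim='sample')
hdi = az.hdi(idata.posterior_predictive).y_hat  # dims: (animal, time, hdi)

fig, axes = plt.subplots(1, len(animals), figsize=(14, 4), dpi=100)
for axis, animal in zip(axes, animals):
    axis.plot(times, posterior_mu.sel(animal=animal), label='Posterior Mean')
    # .values.T has shape (2, n_times), so * unpacks it into the lower and upper bounds
    axis.fill_between(times, *hdi.sel(animal=animal).values.T, alpha=0.25, label='94% HDI')
    axis.plot(times, df[animal].values, ls='--', color='k', label='Data')
    axis.set_title(str(animal))
axes[0].legend()
plt.show()
```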

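Here’s the output (figure not included): one panel per animal, each showing the posterior mean, the 94% HDI band, and the observed data.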
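To answer the original question directly, the same loop also works with the auto-generated dim names, and az.plot_hdi can draw each band if the precomputed bounds for one species are passed through its `hdi_data` argument, which expects an array of shape `(*x.shape, 2)`. A minimal sketch, assuming synthetic data: `times` and `observed` are hypothetical stand-ins for the real time grid and measured concentrations, and the posterior predictive values are random:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import arviz as az

# Synthetic stand-in with the question's dim names (random values, shapes only)
rng = np.random.default_rng(1)
n_chains, n_draws, n_species, n_times = 4, 200, 3, 250
times = np.linspace(0, 10, n_times)               # hypothetical time grid
pp = xr.Dataset(
    {"Y_obs": (("chain", "draw", "Y_obs_dim_2", "Y_obs_dim_3"),
               rng.normal(size=(n_chains, n_draws, n_species, n_times)))},
    coords={"chain": np.arange(n_chains), "draw": np.arange(n_draws),
            "Y_obs_dim_2": np.arange(n_species), "Y_obs_dim_3": np.arange(n_times)},
)
observed = rng.normal(size=(n_species, n_times))  # hypothetical observed data

hdi = az.hdi(pp)["Y_obs"]                         # (Y_obs_dim_2, Y_obs_dim_3, hdi)
mean = pp["Y_obs"].mean(dim=("chain", "draw"))    # (Y_obs_dim_2, Y_obs_dim_3)

fig, axes = plt.subplots(1, n_species, figsize=(14, 4), sharex=True)
for ax, k in zip(axes, hdi.coords["Y_obs_dim_2"].values):
    # hdi_data has shape (n_times, 2); smooth=False draws the raw bounds
    az.plot_hdi(times, hdi_data=hdi.sel(Y_obs_dim_2=k).values, smooth=False,
                fill_kwargs={"alpha": 0.25, "label": "94% HDI"}, ax=ax)
    ax.plot(times, mean.sel(Y_obs_dim_2=k), label="Posterior mean")
    ax.plot(times, observed[k], "k--", label="Data")
    ax.set_title(f"Species {k}")
axes[0].legend()
```

`smooth=False` is a deliberate choice here: plot_hdi smooths the band by default, which can blur sharp features in an ODE trajectory.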