Thanks for following up @OriolAbril!
The following is a code example showing what I came up with. It turns out that `plot_forest` does what I was hoping for. I was also able to make a sort of ridge plot out of BPV (Bayesian p-value) plots, although it isn’t pretty. I did all of this by sampling the posterior predictive 100 times for each draw, which let me calculate a median value for each draw afterwards. I’m not sure whether this step was necessary or not…
Make the example data with this code block:
import numpy as np
import pandas as pd
# Simulation settings: how many groups to generate, and the ranges from which
# each group's size, mean and standard deviation are drawn.
n_groups = 10
n_obs_min, n_obs_max = 5, 50
group_mean_min, group_mean_max = 0, 30
group_std_min, group_std_max = 1, 10


def _simulate_group(group_id):
    """Draw one group's size, mean and std, then its normal observations.

    The draws are made in a fixed order (size, mean, std, observations) from
    NumPy's global random state, so the caller's seed determines the result.
    """
    (size,) = np.random.randint(n_obs_min, n_obs_max, 1)
    (center,) = np.random.uniform(group_mean_min, group_mean_max, 1)
    (spread,) = np.random.uniform(group_std_min, group_std_max, 1)
    observations = np.random.normal(center, spread, size)
    return pd.DataFrame({"y": observations, "group": group_id})


# make the data
np.random.seed(0)
df_ls = [_simulate_group(group_id) for group_id in range(n_groups)]
df = pd.concat(df_ls)
# reset_index() keeps each row's within-group position as an "index" column.
df = df.reset_index()
Now fit a simple unpooled model:
import pymc as pm
# Fit an unpooled model: every group gets its own mean and standard deviation.
data = df.y.values
categories = df.group.unique().tolist()
# Integer position of each observation's group within `categories`.
idx = pd.Categorical(df.group, categories=categories).codes
# "groups" labels the per-group parameters; "groups_flat" labels each
# observation with the group it belongs to.
coords = {"groups": categories, "groups_flat": df.group.values}
with pm.Model(coords=coords) as unpooled_model:
    μ = pm.Normal("μ", mu=(group_mean_min + group_mean_max) / 2, sigma=10, dims="groups")
    σ = pm.HalfNormal("σ", sigma=10, dims="groups")
    Y = pm.Normal("Y", mu=μ[idx], sigma=σ[idx], observed=data, dims="groups_flat")
    idata = pm.sample(random_seed=0)
    # Add a length-100 "pred_id" dimension to the posterior so that, with
    # "pred_id" listed in sample_dims, 100 posterior predictive values are
    # drawn for every (chain, draw) pair.
    idata.posterior = idata.posterior.expand_dims(pred_id=100)
    # extend_inferencedata=True already adds the posterior_predictive group to
    # `idata` in place, so the result is not passed to idata.extend() again.
    pm.sample_posterior_predictive(
        idata,
        extend_inferencedata=True,
        sample_dims=["chain", "draw", "pred_id"],
        random_seed=0,
    )
Now make the forest plot with this code block:
import arviz as az
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
def make_forest_plot(idata, title):
    """Draw a forest plot of posterior predictive medians.

    The posterior predictive group of ``idata`` is extracted with chains and
    draws kept separate, reduced to its median over the ``pred_id`` dimension,
    and passed to ``az.plot_forest`` on a new 11.5 x 5 figure. ``title`` is
    used as the axes title. Nothing is returned.
    """
    medians = az.extract(
        idata, group="posterior_predictive", combined=False
    ).median(dim=["pred_id"])

    _, ax = plt.subplots(figsize=(11.5, 5))
    az.plot_forest(
        medians,
        kind="forestplot",
        linewidth=4,
        combined=True,
        ridgeplot_overlap=1.5,
        ax=ax,
    )
    ax.set_title(title)

    # Major ticks every 5 units, minor ticks every unit, with vertical grid
    # lines drawn at both.
    x_axis = ax.xaxis
    x_axis.set_major_locator(MultipleLocator(5))
    x_axis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")


make_forest_plot(idata, "Each Group's 94% HDI and IQR for its Median Value")
This produces the following plot:
Now produce a BPV ridgeline plot with this code block:
# Stack one Bayesian p-value (median t-stat) panel per group in a single
# column, all sharing the x axis of the first panel.
n_categories = len(categories)
fig = plt.figure(figsize=(11.5, 0.75 * n_categories))
first_ax = None
for i, category in enumerate(categories):
    # On the first pass first_ax is None, which is add_subplot's default
    # (no sharing); later panels share the first panel's x axis.
    ax = fig.add_subplot(n_categories, 1, i + 1, frameon=False, sharex=first_ax)
    if first_ax is None:
        first_ax = ax
    az.plot_bpv(idata.sel(groups_flat=category), kind="t_stat", t_stat="median", ax=ax)
    ax.set_ylabel(category, rotation=0)
    # Replace the title plot_bpv sets; only the top panel carries the
    # figure-level heading.
    ax.set_title("Each Group's Bayesian P Value for its Median Value" if i == 0 else "")
    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")
    if i < n_categories - 1:
        # Hide x tick labels on all but the bottom panel. Target this axes
        # explicitly rather than pyplot's implicit "current axes".
        ax.tick_params(axis="x", labelbottom=False)

