ArviZ plot_forest for Posterior Predictive Median Values?

Thanks for following up @OriolAbril!

The following code example shows what I came up with. It turns out that plot_forest does what I was hoping for. I was also able to make a sort of ridge plot out of BPV (Bayesian p-value) plots, although it isn’t pretty. I did all of this by drawing 100 posterior predictive samples for each posterior draw, which let me calculate a median value for each draw afterwards. I’m not sure whether this step was necessary…

Make the example data with this code block:

import numpy as np
import pandas as pd

n_groups = 10
n_obs_min, n_obs_max = 5, 50  # randint's upper bound is exclusive: 5..49 observations
group_mean_min, group_mean_max = 0, 30
group_std_min, group_std_max = 1, 10

# Simulate one normally distributed sample per group.  Each group gets a random
# size, mean and standard deviation drawn from the ranges above.
np.random.seed(0)
df_ls = []
for i in range(n_groups):
    # Scalar draws (no size argument) rather than a length-1 array indexed with [0].
    n_obs = np.random.randint(n_obs_min, n_obs_max)
    mean = np.random.uniform(group_mean_min, group_mean_max)
    std = np.random.uniform(group_std_min, group_std_max)
    df_ls.append(pd.DataFrame({"y": np.random.normal(mean, std, n_obs), "group": i}))
# ignore_index=True renumbers the rows 0..N-1 without keeping each group's old
# row numbers as an extra column (which .reset_index() would add as "index").
df = pd.concat(df_ls, ignore_index=True)

Now fit a simple unpooled model:

import pymc as pm

# Fit an unpooled model: every group gets its own mean μ and standard deviation σ.
data = df.y.values
categories = df.group.unique().tolist()
# Integer code of each observation's group, used to index the per-group parameters.
idx = pd.Categorical(df.group, categories=categories).codes
coords = {"groups": categories, "groups_flat": df.group.values}
n_pred_per_draw = 100  # posterior predictive samples drawn per posterior draw
with pm.Model(coords=coords) as unpooled_model:
    μ = pm.Normal("μ", mu=(group_mean_min + group_mean_max) / 2, sigma=10, dims="groups")
    σ = pm.HalfNormal("σ", sigma=10, dims="groups")
    Y = pm.Normal("Y", mu=μ[idx], sigma=σ[idx], observed=data, dims="groups_flat")

    idata = pm.sample(random_seed=0)

    # Repeat every posterior draw n_pred_per_draw times along a new "pred_id"
    # dimension, so sample_posterior_predictive produces that many predictive
    # samples per draw.  The expanded posterior is only needed for this call, so
    # the original one is put back afterwards (even on error).  Later posterior
    # summaries and diagnostics then do not see the duplicated draws.
    posterior = idata.posterior
    idata.posterior = posterior.expand_dims(pred_id=n_pred_per_draw)
    try:
        # extend_inferencedata=True already adds the "posterior_predictive" group
        # to idata in place, so no separate idata.extend(...) call is needed.
        pm.sample_posterior_predictive(
            idata,
            extend_inferencedata=True,
            sample_dims=["chain", "draw", "pred_id"],
            random_seed=0,
        )
    finally:
        idata.posterior = posterior

Now make the forest plot with this code block:

import arviz as az
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

def make_forest_plot(idata, title):
    """Draw a forest plot of the posterior predictive median values.

    The posterior predictive samples are first reduced with a median over the
    ``pred_id`` dimension (the repeated predictive samples drawn per posterior
    draw).  This leaves one median per chain, draw and observation.
    ``az.plot_forest`` then summarizes those medians with an HDI interval and
    the interquartile range.

    Parameters
    ----------
    idata : arviz.InferenceData
        Must contain a ``posterior_predictive`` group with a ``pred_id``
        dimension.
    title : str
        Title for the plot.
    """
    post_pred = az.extract(idata, group="posterior_predictive", combined=False)
    post_pred = post_pred.median(dim=["pred_id"])
    # NOTE(review): the "groups_flat" coordinate repeats each group label once per
    # observation; plot_forest appears to pool all observations sharing a label
    # into one row per group -- confirm against the ArviZ version in use.

    _, ax = plt.subplots(figsize=(11.5, 5))
    az.plot_forest(
        post_pred,
        kind="forestplot",
        linewidth=4,
        combined=True,  # merge all chains into one interval per row
        ax=ax,
    )
    ax.set_title(title)
    # Major ticks every 5 units and minor ticks every 1, gridded, to make the
    # interval endpoints easy to read off.
    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")


make_forest_plot(idata, "Each Group's 94% HDI and IQR for its Median Value")

This produces the following plot (image not included in this copy):


Now produce a BPV ridgeline plot with this code block:

# Stack one BPV panel per group vertically, all sharing the first panel's x axis,
# to build a ridgeline-style figure.
n_categories = len(categories)
fig = plt.figure(figsize=(11.5, 0.75 * n_categories))
first_ax = None
for i, category in enumerate(categories):
    # sharex=None (the first iteration) is add_subplot's default, so only the
    # later panels are linked to the first one.
    ax = fig.add_subplot(n_categories, 1, i + 1, frameon=False, sharex=first_ax)
    if first_ax is None:
        first_ax = ax

    # groups_flat repeats the group label once per observation, so selecting the
    # label keeps every observation of this group.
    az.plot_bpv(idata.sel(groups_flat=category), kind="t_stat", t_stat="median", ax=ax)
    ax.set_ylabel(category, rotation=0)
    # Only the top panel carries the figure title; the others get an empty one.
    ax.set_title("Each Group's Bayesian P Value for its Median Value" if i == 0 else "")
    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")

    # Hide the x tick labels on every panel except the bottom one.  Acting on
    # `ax` directly avoids depending on pyplot's implicit "current axes".
    if i < n_categories - 1:
        ax.tick_params(axis="x", labelbottom=False)