ArviZ plot_forest for Posterior Predictive Median Values?

I’d like to generate a plot that compares the dimensions in my model, similar to what plot_forest produces. But instead of showing the HDI of each dimension, I want to plot the HDI of each dimension’s median. Is there a better way of doing this than iterating over each dimension and plotting it with plot_bpv?
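One possible way to state the request precisely (my reading, not something the question spells out): let $\tilde y_i^{(s)}$ be the posterior predictive value for observation $i$ in draw $s = 1, \dots, S$, and let $g_i$ be its group. The quantity to summarize for group $g$ would then be the per-draw median

$$
\tilde m_g^{(s)} = \operatorname{median}\,\{\tilde y_i^{(s)} : g_i = g\}, \qquad s = 1, \dots, S,
$$

and the forest plot would show, for each $g$, an interval such as the 94% HDI of $\tilde m_g^{(1)}, \dots, \tilde m_g^{(S)}$, that is, an interval taken across draws rather than across observations.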

I don’t understand what this means. Could you clarify?

That being said, it is probable that arviz_plots.plot_forest (see the arviz-plots dev documentation) will allow that. There are keywords to change the type of credible interval (HDI or ETI), the probability used for either of them, and the point estimate. You can also compute the values manually and use stats_kwargs to provide them pre-computed.
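To illustrate the manual route for the interval part, here is a minimal NumPy sketch that computes both a 94% HDI and a 94% ETI from a 1-D array of per-draw medians. The hdi and eti helpers are my own (hdi follows the usual "shortest interval containing a fraction prob of the sorted draws" approach); they are not calls into arviz or arviz-plots. The skewed stand-in data is made up purely to show that the two interval types can differ. How the values would then be passed through stats_kwargs is not shown; I'd check the arviz-plots docs for the expected format.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Shortest interval containing a fraction `prob` of the sorted samples."""
    x = np.sort(np.asarray(samples).ravel())
    n = x.size
    k = int(np.floor(prob * n))      # index offset spanned by each candidate interval
    widths = x[k:] - x[: n - k]      # width of every candidate interval
    i = int(np.argmin(widths))       # start of the narrowest one
    return x[i], x[i + k]


def eti(samples, prob=0.94):
    """Equal-tailed interval: cut (1 - prob) / 2 from each tail."""
    lo, hi = np.quantile(np.asarray(samples).ravel(), [(1 - prob) / 2, (1 + prob) / 2])
    return lo, hi


# Hypothetical per-draw medians for one group, deliberately skewed so HDI and ETI differ
rng = np.random.default_rng(0)
draw_medians = rng.exponential(scale=1.0, size=4000)

print("94% HDI:", hdi(draw_medians))
print("94% ETI:", eti(draw_medians))
```

For a symmetric, unimodal set of draws the two intervals nearly coincide; the choice matters mainly when the per-draw medians are skewed.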

If you decide to try it, I’d recommend installing arviz-plots from GitHub (see the arviz-plots dev documentation).

Thanks for following up @OriolAbril!

The following code example shows what I came up with. It turns out that plot_forest does what I was hoping for. I was also able to make a sort of ridge plot out of BPV plots, although it isn’t pretty.

I ended up doing all of this by sampling the posterior predictive distribution 100 times for each posterior draw, which let me calculate a median value for each draw afterwards. I’m not sure whether this step was necessary…
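On whether the 100 extra predictive samples per draw are needed: a quick NumPy simulation (hypothetical posterior draws for a single group, not the fitted model) suggests the two approaches estimate different things. Under the Normal likelihood used here, the predictive distribution of one observation given draw s is Normal(μ_g, σ_g), whose median is μ_g, so the median over many pred_id replicates should land close to μ_g for every draw. Taking instead one replicate per draw and the median over the group's observations gives the sampling distribution of the group median, which is wider. Which one is appropriate depends on the question being asked; the numbers are only illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior draws for a single group (stand-ins for μ and σ in the model)
n_draws, n_obs, n_rep = 2000, 20, 100
mu = rng.normal(15.0, 1.0, size=n_draws)
sigma = np.abs(rng.normal(5.0, 0.3, size=n_draws))

# (a) pred_id-style: n_rep predictive replicates per draw and observation,
#     then the median over the replicates (like post_pred.median(dim=["pred_id"]))
y_rep_many = rng.normal(mu[:, None, None], sigma[:, None, None], size=(n_draws, n_obs, n_rep))
median_over_reps = np.median(y_rep_many, axis=-1)  # shape (n_draws, n_obs)

# (b) one replicate per draw, median over the group's observations
y_rep_one = rng.normal(mu[:, None], sigma[:, None], size=(n_draws, n_obs))
median_over_obs = np.median(y_rep_one, axis=-1)    # shape (n_draws,)

# Spread of each quantity around the per-draw μ
spread_a = np.std(median_over_reps - mu[:, None])
spread_b = np.std(median_over_obs - mu)
print(f"(a) median over {n_rep} replicates: sd around μ ≈ {spread_a:.2f}")
print(f"(b) median over {n_obs} observations: sd around μ ≈ {spread_b:.2f}")
```

With more replicates, (a) collapses further onto μ_g, so a forest plot built from it would mostly mirror the posterior of μ.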

Make the example data with this code block:

```python
import numpy as np
import pandas as pd

n_groups = 10
n_obs_min, n_obs_max = 5, 50
group_mean_min, group_mean_max = 0, 30
group_std_min, group_std_max = 1, 10

# make the data
np.random.seed(0)
df_ls = []
for i in range(n_groups):
    n_obs = np.random.randint(n_obs_min, n_obs_max, 1)[0]
    mean = np.random.uniform(group_mean_min, group_mean_max, 1)[0]
    std = np.random.uniform(group_std_min, group_std_max, 1)[0]
    df_ls += [pd.DataFrame({"y": np.random.normal(mean, std, n_obs), "group": i})]
df = pd.concat(df_ls).reset_index()
```

Now fit a simple unpooled model:

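```python
import pymc as pm

# fit an unpooled model
data = df.y.values
categories = df.group.unique().tolist()
idx = pd.Categorical(df.group, categories=categories).codes
# "groups_flat" repeats each group label once per observation, so that
# idata.sel(groups_flat=category) later selects all observations of that group
coords = {"groups": categories, "groups_flat": df.group.values}
with pm.Model(coords=coords) as unpooled_model:
    μ = pm.Normal("μ", mu=(group_mean_min + group_mean_max)/2, sigma=10, dims="groups")
    σ = pm.HalfNormal("σ", sigma=10, dims="groups")
    Y = pm.Normal("Y", mu=μ[idx], sigma=σ[idx], observed=data, dims="groups_flat")

    idata = pm.sample(random_seed=0)
    # repeat every posterior draw 100 times along a new "pred_id" dimension
    # (note: this also modifies idata.posterior itself)
    idata.posterior = idata.posterior.expand_dims(pred_id=100)
    # extend_inferencedata=True already adds the posterior_predictive group to idata,
    # so wrapping the call in idata.extend(...) is not needed
    pm.sample_posterior_predictive(
        idata,
        extend_inferencedata=True,
        sample_dims=["chain", "draw", "pred_id"],
        random_seed=0,
    )
```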

Now make the forest plot with this code block:

```python
import arviz as az
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

def make_forest_plot(idata, title):
    # posterior predictive samples, keeping chain and draw as separate dimensions
    post_pred = az.extract(idata, group="posterior_predictive", combined=False)
    # median over the 100 replicates drawn for each posterior draw
    post_pred = post_pred.median(dim=["pred_id"])

    fig, ax = plt.subplots(figsize=(11.5, 5))
    az.plot_forest(
        post_pred,
        kind="forestplot",
        linewidth=4,
        combined=True,          # merge chains into a single interval
        ridgeplot_overlap=1.5,  # only used when kind="ridgeplot"
        ax=ax,
    )
    ax.set_title(title)
    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")

make_forest_plot(idata, "Each Group's 94% HDI and IQR for its Median Value")
```

This produces the following plot:

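If the goal is one row per group rather than one row per observation, a possible variant (a sketch, not tested against the model above) is to reduce the posterior predictive array to per-group medians for each draw before plotting. The helper group_medians is hypothetical; it assumes an array of shape (chain, draw, n_obs) and one integer group code per observation, like idx in the model block.

```python
import numpy as np


def group_medians(y_rep, idx, n_groups):
    """Per-draw median of the predicted values in each group.

    y_rep: array of shape (chain, draw, n_obs)
    idx: integer group code for each observation, shape (n_obs,)
    returns: array of shape (chain, draw, n_groups)
    """
    idx = np.asarray(idx)
    return np.stack(
        [np.median(y_rep[..., idx == g], axis=-1) for g in range(n_groups)],
        axis=-1,
    )


# Small synthetic check: 2 chains, 3 draws, 5 observations in 2 groups
idx = np.array([0, 0, 1, 1, 1])
y_rep = np.arange(2 * 3 * 5, dtype=float).reshape(2, 3, 5)
medians = group_medians(y_rep, idx, n_groups=2)
print(medians.shape)  # (2, 3, 2)
```

With the model above, y_rep would presumably be idata.posterior_predictive["Y"].values from a run without the pred_id trick. The result could then be wrapped with something like az.from_dict(posterior={"Y_median": medians}, coords={"groups": categories}, dims={"Y_median": ["groups"]}) and passed to az.plot_forest to get one interval per group; I have not run this end to end.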

Now produce a BPV ridgeline plot with this code block:

```python
n_categories = len(categories)
fig = plt.figure(figsize=(11.5, 0.75*n_categories))
first_ax = None
for i, category in enumerate(categories):
    # share the x axis with the first subplot (sharex=None for the first one itself)
    ax = fig.add_subplot(n_categories, 1, i + 1, frameon=False, sharex=first_ax)
    if first_ax is None:
        first_ax = ax
    az.plot_bpv(idata.sel(groups_flat=category), kind="t_stat", t_stat="median", ax=ax)
    ax.set_ylabel(category, rotation=0)
    ax.set_title("")
    ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(True, axis="x", which="both")

    # hide x tick labels on every row except the bottom one
    if i < n_categories - 1:
        ax.tick_params(axis="x", labelbottom=False)

    if i == 0:
        ax.set_title("Each Group's Bayesian P Value for its Median Value")
```