Unexpected behavior with arviz.plot_hdi() with categorical x

Hi all!

As a big Stan and brms user, I am really enjoying learning to fit Bayesian models in Python with the PyMC ecosystem.

While attempting to refit models from my own work in bambi, I ran into some trouble plotting posterior expectations and 95% HDIs from an ANOVA model. Here’s a simple simulation to demonstrate the issue:

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Simulate data
np.random.seed(42)
x = ['A', 'B', 'C']
yA = np.random.normal(loc=5, scale=3, size=30)
yB = np.random.normal(loc=2, scale=4, size=30)
yC = np.random.normal(loc=7, scale=1.8, size=30)

# Create a DataFrame
data = pd.DataFrame({
    'y': np.concatenate([yA, yB, yC]),
    'group': np.repeat(x, repeats=30)
})
data['group'] = data['group'].astype('category')

# Plot the data
plt.figure(figsize=(8, 6))
sns.boxplot(x='group', y='y', data=data, palette='Set3')
sns.stripplot(x='group', y='y', data=data, color='black', alpha=0.5, jitter=True)
plt.title('Distribution of y across groups')
plt.xlabel('Group')
plt.ylabel('y')
plt.show()

# Fit a Bayesian ANOVA model using Bambi
model = bmb.Model('y ~ group', data)
idata = model.fit()

preds = model.predict(idata, kind="response_params", inplace=False)
y_mu = az.extract(preds["posterior"])["mu"].values
group = data.loc[:, "group"].values

Now I want to use az.plot_hdi() to plot the expectation of the posterior distribution that I extracted (in preds above).

Following the ArviZ documentation, I first tried:


az.plot_hdi(x=group, y=y_mu.T)

#UFuncTypeError: ufunc 'multiply' did not contain a loop with signature matching types (dtype('<U1'), dtype('float64')) -> None

It turns out that az.plot_hdi has a parameter called smooth that is True by default. When smoothing, the function attempts to interpolate across the range of x, which NumPy cannot do when x is categorical. Setting smooth=False works around the error, but it does not produce the plot I’d expect: it interpolates the HDI across groups instead of plotting, say, a separate segment for each group’s HDI:

az.plot_hdi(x=group, y=y_mu.T, smooth=False)
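To see why the default smooth=True path fails, here is a minimal plain-NumPy reproduction. It assumes (this is not copied from the ArviZ source) that smoothing starts by building an evenly spaced grid between x.min() and x.max(), for example with np.linspace. The labels are stored as an object array, which is what np.asarray typically returns for a pandas Categorical of strings; the helper `numeric_grid` is hypothetical.

```python
import numpy as np


def numeric_grid(x, n=200):
    """Hypothetical stand-in for a smoothing step: evenly spaced grid over x."""
    x = np.asarray(x)
    return np.linspace(x.min(), x.max(), n)


# Labels like the post's, as an object array (what np.asarray typically
# returns for a pandas Categorical of strings)
group = np.repeat(["A", "B", "C"], 30).astype(object)

try:
    numeric_grid(group)  # min/max work ('A', 'C'), but spacing them out does not
except TypeError as err:  # NumPy raises a TypeError subclass here
    print("string labels fail:", err)

# Integer codes (A -> 0, B -> 1, C -> 2) give NumPy numbers it can space out
levels, codes = np.unique(group, return_inverse=True)
grid = numeric_grid(codes)
print(levels, grid[0], grid[-1])
```

Mapping the labels to codes removes the error, but a band drawn over those codes would still connect neighboring groups, so a per-group summary fits a categorical x better.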

Am I missing something about how az.plot_hdi() is intended to behave? The documentation doesn’t say that the function only works when x is numeric, but that seems to be the behavior. Perhaps there’s a plot_kwarg I should pass through to matplotlib to get a more ANOVA-like plot?
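One way to get an ANOVA-style plot without touching ArviZ internals is to compute one HDI per group with az.hdi and draw each as its own vertical segment in matplotlib. This is a sketch, not an ArviZ feature: the helpers `group_hdis` and `plot_group_hdis` are hypothetical, and they assume the draws have shape (n_samples, n_obs), like y_mu.T above. Because mu is constant within a group in `y ~ group`, any single column of a group holds that group’s posterior draws of the mean. Stand-in draws replace the fitted model so the block runs on its own.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np


def group_hdis(y_draws, group, hdi_prob=0.95):
    """Hypothetical helper: posterior mean and HDI of mu for each group level.

    y_draws: array of shape (n_samples, n_obs), e.g. y_mu.T from the post.
    group:   length-n_obs labels (strings or a pandas Categorical).
    """
    group = np.asarray(group)
    levels = np.unique(group)  # sorted unique labels
    means, hdis = [], []
    for level in levels:
        # In y ~ group, mu is the same for every row of a group, so the
        # first matching column already holds that group's draws.
        col = np.flatnonzero(group == level)[0]
        draws = y_draws[:, col]
        means.append(draws.mean())
        hdis.append(az.hdi(draws, hdi_prob=hdi_prob))  # array [low, high]
    return levels, np.array(means), np.array(hdis)


def plot_group_hdis(y_draws, group, hdi_prob=0.95, ax=None):
    """Hypothetical helper: one HDI segment plus a mean marker per group."""
    levels, means, hdis = group_hdis(y_draws, group, hdi_prob)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))
    positions = np.arange(len(levels))  # one integer slot per category
    ax.vlines(positions, hdis[:, 0], hdis[:, 1], linewidth=4,
              label=f"{hdi_prob:.0%} HDI of mu")
    ax.plot(positions, means, "o", color="black", label="posterior mean of mu")
    ax.set_xticks(positions)
    ax.set_xticklabels(levels)
    ax.set_xlabel("Group")
    ax.set_ylabel("y")
    ax.legend()
    return ax


# Stand-in posterior draws (not a fitted model): 1000 samples of mu that are
# identical within each group, laid out like y_mu.T with 30 rows per group.
rng = np.random.default_rng(42)
group = np.repeat(["A", "B", "C"], 30)
mu_by_group = {"A": rng.normal(5, 0.5, 1000),
               "B": rng.normal(2, 0.7, 1000),
               "C": rng.normal(7, 0.3, 1000)}
y_draws = np.column_stack([mu_by_group[g] for g in group])

levels, means, hdis = group_hdis(y_draws, group)
ax = plot_group_hdis(y_draws, group)
```

With the objects from the post, the equivalent call would be `plot_group_hdis(y_mu.T, group)`, followed by `plt.show()`.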

If this is not expected behavior, I am happy to try to figure out a solution and contribute to the repo. I wanted to open a discussion before filing a GitHub issue, though.


Hi and welcome!

I don’t think it’s possible to modify the behavior of arviz.plot_hdi() without modifying its internals. However, you can still do what you want with the interpret sub-package in Bambi.

bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    conditional="group",
);

That displays the 94% HDI for the mean parameter. If you want it for the predictive distribution, you can do

bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    conditional={"group": ["A", "B", "C"]},
    pps=True
);

You can see more examples in “Tools to interpret model outputs”.


It should be possible to map group levels to colors but there’s a bug right now. I’m opening an issue in the Bambi repository to fix it.

Edit: issue opened here: “bambi.interpret_plot_predictions() fails when we condition and color by the same categorical variable” (Issue #870, bambinos/bambi).

Thank you, @tcapretto! This makes sense to me. Peeking at the function definition, I see what you mean about needing to modify internals. Thank you for providing an alternative solution using the interpret sub-module; I will play around with this! I think I should do some reading about plotting models fit with PyMC and bambi. I am much less comfortable here than I am with brms and bayesplot’s add_pred_draws() and the ggplot2 / ggdist workflow in R. That said, I am excited about how lively the PyMC community is!

If you’d like, I can open an issue on ArviZ about the az.plot_hdi function. I think raising a ValueError when categorical or string-typed x values are supplied (with smooth=True) would be more informative than the low-level NumPy UFuncTypeError that is currently raised in that case. I am more than happy to submit a PR for this as well.
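The proposed check can be sketched in a few lines. This is a hypothetical helper written for illustration, not the code that went into ArviZ; the name `check_numeric_x` and its `smooth` argument are assumptions. It converts x with np.asarray and rejects any non-numeric dtype before any interpolation happens.

```python
import numpy as np


def check_numeric_x(x, smooth=True):
    """Hypothetical validation for plot_hdi's x (not ArviZ's actual code).

    Returns x as a NumPy array; when smooth=True and x is not numeric,
    raises a ValueError that names the problem instead of a NumPy error.
    """
    arr = np.asarray(x)  # a pandas Categorical of strings becomes object dtype
    if smooth and not np.issubdtype(arr.dtype, np.number):
        raise ValueError(
            f"x must be numeric when smooth=True, got dtype {arr.dtype}. "
            "Convert categories to numeric codes or pass smooth=False."
        )
    return arr


check_numeric_x([0.5, 1.0, 2.0])  # numeric input passes through unchanged
try:
    check_numeric_x(np.repeat(["A", "B", "C"], 30))
except ValueError as err:
    print(err)
```

Checking the dtype up front keeps the failure at the API boundary, where the message can mention the smooth parameter that actually controls the behavior.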

-Miles


Edit: Issue and PR created: [WIP] `plot_hdi` raise exception when `x` is string (#2412) by milesalanmoore, Pull Request #2413, arviz-devs/arviz on GitHub.


11 posts were split to a new topic: User experience: Python vs R, PyMC vs Stan vs JAX