Argument for pm.plot_hdi()

Hi all,
I’m trying to translate some PyMC3 code into PyMC v5, but the code below doesn’t work.
If anyone knows what to do, please let me know.

My PyMC version is 5.16.2 and my Python version is 3.11.6.

import numpy as np
import pymc as pm
import matplotlib.pyplot as plt
import arviz as az

n = 30
x = np.zeros(n)
y = np.zeros(n, dtype=int)

Dist_s = [2.4, 2.8]
Dist_w = [0.8, 1.6]

rng = np.random.default_rng(10)
for i in range(n):
    wk = rng.random()
    y[i] = 0 * (wk < 0.5) + 1 * (wk >= 0.5)
    x[i] = rng.random() * Dist_w[y[i]] + Dist_s[y[i]]
    
data = {'x': x, 'y': y}

# centerize the data
data['x_c'] = data['x'] - data['x'].mean()

with pm.Model() as model_l:
    beta0 = pm.Normal('beta0', mu=0, sigma=100)
    beta1 = pm.Normal('beta1', mu=0, sigma=100)
    
    mu = pm.Deterministic('mu', pm.math.sigmoid(beta0 + beta1 * data['x_c']))
    boundary = pm.Deterministic('boundary', -beta0 / beta1)
    
    pm.Bernoulli('y', p=mu, observed=data['y'])
    
    trace_l = pm.sample(random_seed=1)

trace_l_ext = az.extract(trace_l)

fig, ax = plt.subplots(constrained_layout=True)

ax.scatter(data['x'], data['y'], color=[f'C{i}' for i in data['y']])

mu = trace_l_ext['mu'].mean('sample')  # az.extract puts the stacked 'sample' dim last, so average over it by name
idx = np.argsort(data['x'])
ax.plot(data['x'][idx], mu[idx])
# pm.plot_hdi(data['x'], trace_l['mu'], ax=ax) # for pymc3
pm.plot_hdi(data['x'], trace_l_ext['mu'], ax=ax) # doesn't work
# pm.plot_hdi(data['x'], pp.posterior_predictive['mu'], ax=ax) # doesn't work

ax.vlines(trace_l_ext['boundary'].mean() + data['x'].mean(), 0, 1, color='k')
# pass a plain 1-D array so hdi returns [lower, upper]
boundary_hdi = az.hdi(trace_l_ext['boundary'].values) + data['x'].mean()
ax.fill_betweenx([0, 1], boundary_hdi[0], boundary_hdi[1], alpha=0.5)
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$y$')

The problem is pm.plot_hdi().
In PyMC3, pm.plot_hdi(data['x'], trace_l['mu'], ax=ax) worked, but I don’t know what to pass as the second argument in PyMC v5.
Could anyone help me?
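To see why the extracted array doesn't fit the second argument, here is a minimal sketch that fakes a posterior with ArviZ instead of running the sampler. The random values and the shape (4 chains, 1000 draws, 30 observations) are assumptions standing in for trace_l. It compares the layout in the posterior group with the layout after az.extract, and what az.hdi (which plot_hdi calls on its y argument) returns for each.

```python
import numpy as np
import arviz as az

# Hypothetical stand-in for trace_l: 4 chains x 1000 draws of a
# 30-element Deterministic "mu", with random values instead of real samples.
rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 4, 1000, 30
idata = az.from_dict(posterior={"mu": rng.random((n_chains, n_draws, n_obs))})

# Layout in the posterior group: (chain, draw, mu_dim_0)
posterior_mu = idata.posterior["mu"]
print(posterior_mu.dims, posterior_mu.shape)  # ('chain', 'draw', 'mu_dim_0') (4, 1000, 30)

# az.extract stacks chain and draw into one trailing "sample" dimension
extracted_mu = az.extract(idata)["mu"]
print(extracted_mu.dims, extracted_mu.shape)  # ('mu_dim_0', 'sample') (30, 4000)

# A 2-D array is read as (draw, *shape), so the 30 observations are
# mistaken for draws and the HDI no longer lines up with x (may also warn).
wrong = az.hdi(extracted_mu.values)
print(wrong.shape)  # not (30, 2)

# A 3-D array is read as (chain, draw, *shape): one interval per observation
right = az.hdi(posterior_mu.values)
print(right.shape)  # (30, 2)
```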

See the documentation for arviz.plot_hdi(). Without seeing the error you are encountering, it’s difficult to diagnose, but I suspect that you need to select the variable from trace_l.posterior instead of from trace_l.
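A minimal sketch of that suggestion, assuming a fake posterior built with az.from_dict in place of trace_l and random x values in place of data['x'] (both hypothetical). It passes the variable straight from the posterior group, which keeps the (chain, draw, observation) layout, and also shows the alternative of precomputing the interval and handing it over via hdi_data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import arviz as az

# Hypothetical stand-ins for data['x'] and trace_l from the question
rng = np.random.default_rng(0)
x = np.sort(rng.random(30) * 2 + 2)
idata = az.from_dict(posterior={"mu": rng.random((4, 1000, 30))})

fig, ax = plt.subplots()

# Option 1: pass the posterior variable, dims (chain, draw, mu_dim_0)
az.plot_hdi(x, idata.posterior["mu"], ax=ax)

# Option 2: precompute the HDI (one [lower, upper] row per x value)
hdi_mu = az.hdi(idata.posterior["mu"].values)
az.plot_hdi(x, hdi_data=hdi_mu, ax=ax)
```

In the question's code this corresponds to az.plot_hdi(data['x'], trace_l.posterior['mu'], ax=ax). pm.plot_hdi should accept the same argument, since PyMC re-exports ArviZ's plotting functions.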

@cluhmann Thank you for your accurate comment.
My code now works fine!
