Hello,
Could someone verify how to pull 95% confidence intervals out of an arviz object and display them correctly? The docs are clear on the mean, but I’m not sure I’m handling the quantiles correctly.
Here is how I’m pulling it:
fig = plt.figure(figsize=(15, 4))
sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.mean(dim=['chain', 'draw'])), label='Posterior Predictive')
sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.quantile(q=.025, dim=['chain', 'draw'])), label='Confidence', color='r')
sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.quantile(q=.975, dim=['chain', 'draw'])), label='Confidence', color='r')
sns.scatterplot(x=df_sat_test['t'], y=np.array(df_sat_test['eaches']), label='True Value')
plt.legend()
This yields:
I’m trying to pull the 95% interval correctly and display it as a filled, semi-transparent band, like the beautiful graphics arviz produces by default.
Using xarray’s quantile() looks correct to me. Just to be precise, these are (Bayesian) credible intervals, not (frequentist) confidence intervals. Alternatively, you can use the arviz hdi() function to get the highest density interval (i.e., the narrowest credible interval for a given probability).
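To make the difference between the two kinds of interval concrete, here is a minimal sketch on synthetic, skewed draws. It uses only NumPy. The helper names `equal_tailed_interval` and `narrowest_interval` are hypothetical, and `narrowest_interval` is a hand-rolled illustration of the "narrowest interval" idea, not arviz's own hdi() implementation.

```python
import numpy as np


def equal_tailed_interval(samples, prob=0.95):
    """Interval between the (1 - prob)/2 and (1 + prob)/2 quantiles."""
    tail = (1.0 - prob) / 2.0
    lo, hi = np.quantile(samples, [tail, 1.0 - tail])
    return float(lo), float(hi)


def narrowest_interval(samples, prob=0.95):
    """Narrowest window of sorted draws that contains a `prob` fraction of them."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.ceil(prob * n))            # number of draws the window must hold
    widths = x[k - 1:] - x[: n - k + 1]   # width of every window of k sorted draws
    i = int(np.argmin(widths))            # start index of the narrowest window
    return float(x[i]), float(x[i + k - 1])


# Skewed stand-in for posterior draws (synthetic, not from the thread's model).
rng = np.random.default_rng(0)
draws = rng.gamma(shape=2.0, scale=1.0, size=4000)

eti = equal_tailed_interval(draws)
hdi = narrowest_interval(draws)
print("equal-tailed:", eti)
print("narrowest:   ", hdi)
```

For a symmetric, unimodal distribution the two intervals nearly coincide. For a skewed one like this, the narrowest interval shifts toward the mode and comes out shorter.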
Thanks. Understood they are not frequentist intervals. My bad for using the wrong terminology.
Figured out the syntax:
fig = plt.figure(figsize=(15, 4))
# Posterior predictive mean, then the 2.5% and 97.5% quantiles over all chains and draws.
a = sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.mean(dim=['chain', 'draw'])), label='Posterior Prediction', color='red')
b = sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.quantile(q=.025, dim=['chain', 'draw'])), label='Confidence', color='skyblue', alpha=.3)
c = sns.lineplot(x=df_sat_test['t'], y=np.array(test_ppc.posterior_predictive.obs.quantile(q=.975, dim=['chain', 'draw'])), label='Confidence', color='skyblue', alpha=.3)
# All three calls draw on the same Axes, so get_lines() returns [mean, lower, upper].
line = c.get_lines()
# Shade the region between the lower and upper quantile lines.
plt.fill_between(line[0].get_xdata(), line[1].get_ydata(), line[2].get_ydata(), color='skyblue', alpha=.3)
sns.scatterplot(x=df_sat_test['t'], y=np.array(test_ppc.observed_data.obs), label='True Value', color='black')
plt.legend()
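A possible variation, sketched under assumptions: compute the band once with NumPy and hand it straight to fill_between, so nothing depends on the order of the lines already on the axes. The function name `plot_interval_band` and the synthetic data are hypothetical stand-ins for `df_sat_test` and `test_ppc`, and the draws are assumed to be ordered (chain, draw, time).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_interval_band(t, draws, observed, prob=0.95):
    """Plot the mean, a shaded equal-tailed band, and the observed points.

    `draws` is assumed to have shape (chain, draw, time).
    """
    tail = (1.0 - prob) / 2.0
    mean = draws.mean(axis=(0, 1))
    # Quantiles over the chain and draw axes; the result has shape (2, time).
    lower, upper = np.quantile(draws, [tail, 1.0 - tail], axis=(0, 1))

    fig, ax = plt.subplots(figsize=(15, 4))
    ax.plot(t, mean, color="red", label="Posterior prediction")
    ax.fill_between(t, lower, upper, color="skyblue", alpha=0.3,
                    label=f"{prob:.0%} credible interval")
    ax.scatter(t, observed, color="black", label="True value")
    ax.legend()
    return fig, ax, lower, upper


# Synthetic stand-in data (hypothetical; not the thread's actual data or model).
rng = np.random.default_rng(1)
t = np.arange(30)
draws = rng.normal(loc=t, scale=2.0, size=(4, 500, t.size))
observed = t + rng.normal(scale=2.0, size=t.size)

fig, ax, lower, upper = plot_interval_band(t, draws, observed)
```

With the objects from this thread, the `draws` argument would presumably be `test_ppc.posterior_predictive.obs.values`, provided its dimensions really are ordered (chain, draw, time). This also gives a single legend entry for the band instead of two 'Confidence' lines.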