How to use ax argument of arviz.plot_dist_comparison?

Hi,

I only want to produce a single plot per variable - the combined plot of the prior and posterior distribution. How do I do this using arviz.plot_dist_comparison?
I believe it has something to do with the ax argument. Here’s a snippet of my code:

fig, ax_combined = plt.subplots(figsize=(12, 5))

for varname in var_names:
    # Plot only the combined plot, setting ax=None for before and after plots
    az.plot_dist_comparison(trace, var_names=varname, kind='latent', ax=(None, None, ax_combined))
    ax_combined.set_title(f"Posterior Density Plot for Parameter: {varname}")

plt.tight_layout()
plt.show()

I get the following error:

Running on PyMC v5.8.0
Traceback (most recent call last):
  File "/home/___.py", line 212, in <module>
    az.plot_dist_comparison(trace, var_names=varname, kind='latent', ax=(None, None, ax_combined))
  File "/home/***/anaconda3/envs/sunode-env/lib/python3.11/site-packages/arviz/plots/distcomparisonplot.py", line 195, in plot_dist_comparison
    axes = plot(**distcomparisonplot_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/***/anaconda3/envs/sunode-env/lib/python3.11/site-packages/arviz/plots/backends/matplotlib/distcomparisonplot.py", line 80, in plot_dist_comparison
    if ax.shape != (nvars, ngroups + 1):
       ^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'

I’m not proficient in Python. Any help is greatly appreciated. Thanks.


I think you are pretty much there. The tuple just needs to be a 2D NumPy array.

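This code skips the first two panels that plot_dist_comparison() would normally draw (prior and posterior on their own) and fills only the combined one. The array needs one row per variable, so if trace has more than one variable, also pass var_names=varname as in your loop: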

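import numpy as np
import matplotlib.pyplot as plt
import arviz as az
fig, ax_combined = plt.subplots(figsize=(12, 5))
# One row per variable; columns are (prior, posterior, combined)
axes = az.plot_dist_comparison(trace, ax=np.array([[None, None, ax_combined]]))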
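If you want one figure per variable, as in the loop from the question, here is a minimal sketch. It assumes ArviZ's bundled centered_eight example data in place of trace, and only scalar variables (mu, tau); a variable with extra dimensions such as theta would, as far as I can tell, need one row of axes per element. Instead of None, it hands plot_dist_comparison real axes for the prior and posterior panels on a throwaway figure and then closes that figure, so it does not depend on how None entries are handled. That workaround is my own choice, not a documented pattern, and names like scratch_fig, ax_grid and combined_axes are just illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np
import arviz as az

# Assumption: ArviZ's bundled example data stands in for `trace`;
# it contains both a prior and a posterior group.
idata = az.load_arviz_data("centered_eight")
var_names = ["mu", "tau"]  # scalar variables, so each needs one row of axes

combined_axes = {}
for varname in var_names:
    # Throwaway figure that receives the separate prior and posterior panels.
    scratch_fig, (ax_prior, ax_post) = plt.subplots(1, 2)
    # The figure we keep: one combined panel for this variable.
    fig, ax_combined = plt.subplots(figsize=(12, 5))

    # For kind="latent" the backend expects ax.shape == (n_vars, 3):
    # column 0 = prior, column 1 = posterior, column 2 = combined.
    ax_grid = np.array([[ax_prior, ax_post, ax_combined]], dtype=object)
    az.plot_dist_comparison(idata, kind="latent", var_names=[varname], ax=ax_grid)

    ax_combined.set_title(f"Prior vs. posterior density for parameter: {varname}")
    plt.close(scratch_fig)  # discard the panels we did not want
    combined_axes[varname] = ax_combined
```

In a script, finish with plt.show() (or call fig.savefig(...) inside the loop) to display or save the figures. Creating a new figure per variable also avoids drawing every variable onto the same ax_combined, which is what reusing one axes object across the loop would do.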

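Tuples don't have a shape attribute, but the matplotlib backend checks ax.shape against (nvars, ngroups + 1), which is (1, 3) for one variable with kind='latent'. So this works:

np.array([[None,None,3]]).shape  # (1, 3)

But this doesn't:

(None,None,3).shape  # AttributeError: 'tuple' object has no attribute 'shape'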
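The double brackets matter as well: a 1D array does have a shape, just the wrong one. A small NumPy-only check, using a string as a stand-in for the Axes object (n_vars, n_groups and expected are illustrative names that mirror the check in the traceback):

```python
import numpy as np

# kind="latent" compares two groups (prior, posterior), so the backend wants
# ax.shape == (n_vars, n_groups + 1); here a single variable.
n_vars, n_groups = 1, 2
expected = (n_vars, n_groups + 1)

# "combined" is a placeholder standing in for a matplotlib Axes object.
flat = np.array([None, None, "combined"], dtype=object)    # single brackets
grid = np.array([[None, None, "combined"]], dtype=object)  # double brackets

print(flat.shape, flat.shape == expected)  # (3,) False
print(grid.shape, grid.shape == expected)  # (1, 3) True
print(hasattr((None, None, "combined"), "shape"))  # False
```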

Thank you so much!!!


If you don’t want the multiple-subplot structure provided by plot_dist_comparison, you are probably better served by plot_density or plot_posterior.

plot_density for example has direct support for multiple datasets:

import arviz as az
centered = az.load_arviz_data('centered_eight')
az.plot_density([centered.posterior, centered.prior], data_labels=["posterior", "prior"])
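To get exactly one panel per variable with both distributions overlaid, the same call can be restricted with plot_density's var_names argument. A sketch, assuming the same centered_eight example data; note that plot_density puts all the selected panels in one figure rather than one figure per variable.

```python
import numpy as np
import arviz as az

centered = az.load_arviz_data("centered_eight")

# One panel per selected variable, posterior and prior overlaid in each.
axes = az.plot_density(
    [centered.posterior, centered.prior],
    data_labels=["posterior", "prior"],
    var_names=["mu", "tau"],
)
```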
