How to get the az.plot_trace() legend of a multi-dimensional β distribution to show dims instead of chains

I have a hierarchical model and I’m trying to get the legend in az.plot_trace() to show the coords variable names, not the chains.

az.plot_trace(idata, var_names=['β'], compact=True, combined=True, legend=True, coords={'beta_list': ['Dim1']}, figsize=(14, 6))

where:

β = pm.Normal("β", mu=mu_β, sigma=sigma_β, dims=("group", "beta_list"))

No matter what I try, the legend only shows the chains.

Any thoughts would be greatly appreciated. Thank you!

There is a bug in how ArviZ decides where to place the multiple legends involved in plot_trace.

You can reproduce your issue with:

import arviz as az
idata = az.load_arviz_data("rugby")
az.plot_trace(idata, var_names="atts", compact=True, combined=True, legend=True, show=True)

where no team legend is added. One possible workaround is to also plot a scalar variable from the model. This generates an extra row of axes where ArviZ places the chain legend, so the chain legend no longer overwrites the team legend.

az.plot_trace(idata, var_names=["home", "atts"], compact=True, combined=True, legend=True, show=True)

so the chain legend is placed on the axes of the home variable and the team legend is then added correctly. It is also possible to generate only the single plot you want (atts, in its own figure) with something like:

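import matplotlib.pyplot as plt
import numpy as np

# Row 0 (atts) is drawn in the first figure, row 1 (home) in a second figure you can discard
_, ax1 = plt.subplots(1, 2)
_, ax2 = plt.subplots(1, 2)
ax = np.vstack((ax1, ax2))
az.plot_trace(idata, var_names=["atts", "home"], compact=True, combined=True, legend=True, axes=ax, show=True)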
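To see why the axes trick works, here is a minimal sketch that uses only matplotlib and NumPy (it does not call ArviZ). It assumes, as the example above relies on, that plot_trace fills one row of the axes array per entry of var_names, in order. The in-memory PNG buffer is just a stand-in for saving the figure you want to keep.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Two independent figures, each with one row of two axes (density, trace).
fig1, ax1 = plt.subplots(1, 2)
fig2, ax2 = plt.subplots(1, 2)

# Stacking the two 1-D axes arrays gives a 2x2 array: one row per variable.
ax = np.vstack((ax1, ax2))
print(ax.shape)  # (2, 2)

# Row 0 (first entry of var_names, "atts" in the example) lives in fig1,
# row 1 (the scalar "home") lives in fig2.
print(all(a.figure is fig1 for a in ax[0]))  # True
print(all(a.figure is fig2 for a in ax[1]))  # True

# Keep only the figure you care about and discard the other one.
buf = io.BytesIO()
fig1.savefig(buf, format="png")
plt.close(fig2)
```

Nothing forces the rows of the axes array to share a figure, so the extra scalar variable can absorb the chain legend in a figure that is never shown or saved.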

Note that the variable order is key to getting the desired result and working around the bug.
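The overwrite itself is ordinary matplotlib behaviour. The following is a minimal matplotlib-only sketch, not ArviZ's actual plotting code, assuming both legends end up on the same axes: a second ax.legend() call replaces the first one, and the standard matplotlib way to keep both is to re-attach the first legend with ax.add_artist(). The team/chain labels and the random data are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

rng = np.random.default_rng(0)
fig, ax = plt.subplots()

# One line per (hypothetical) team, like a compact trace plot draws per coord.
team_lines = [
    ax.plot(rng.normal(size=50).cumsum(), label=f"team_{i}")[0] for i in range(3)
]

# First legend: the per-team ("dims") legend.
team_legend = ax.legend(handles=team_lines, title="team", loc="upper left")

# Second legend: a chain legend. This call replaces team_legend on the axes.
chain_handles = [
    Line2D([], [], color="k", linestyle=ls, label=str(chain))
    for chain, ls in enumerate(["-", "--"])
]
ax.legend(handles=chain_handles, title="chain", loc="upper right")
print(ax.get_legend() is team_legend)  # False: the team legend was replaced

# Matplotlib-level fix: re-attach the first legend as an extra artist.
ax.add_artist(team_legend)
fig.canvas.draw()  # both legends are now drawn
```

In a real trace plot you would first have to find which axes ArviZ drew each legend on, so the variable-order workaround is usually simpler.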


Thank you so much for the explanation and the workaround. Side note: I really enjoyed your ArviZ in depth blog entries. Cheers!
