I think I’m answering my own question here, so this is to share and get feedback (if any).
I’ve been wanting to use the arviz.plot_forest() function to visualize a trace. However, the trace contains a number of variables, and each variable has hundreds if not thousands of parameters.
For example, I have a trace that contains posteriors for a parameter for 156 products by 52 weeks. I wanted to create a forestplot for only one of those products.
My data is arranged row-wise, so I need to index into the correct product.
Here’s what I’ve come up with:
import numpy as np
import arviz as az
weeks = [...array of 52 weeks...]
products = [....array of 156 products...]
product_week_idx = [...index of 8,112 product/week combinations]
trace = [...PyMC3 trace with 6000 samples from 3 chains, 8112 parameters per variable]
# e.g.
# trace["my_trace_variable"].shape
# > (6000, 8112)
# find index of selected product in products array
# (np.where returns a tuple of arrays, so [0][0] pulls out the scalar position)
product_name = "The Product I Want to Plot"
selected_product_index = np.where(products == product_name)[0][0]

# find index of selected product/week combos in products x weeks array
# ([0] unpacks the tuple into a flat integer array)
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]
# selected_product_week_index.shape
# > (52,)
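To make the two lookups concrete, here is a minimal sketch on toy data. The sizes (3 products, 4 weeks) and the way product_week_idx is built with np.repeat are assumptions for illustration only; the idea is just that each column of the trace carries the integer index of the product it belongs to.

```python
import numpy as np

# Hypothetical toy data: 3 products observed over 4 weeks, stored row-wise.
products = np.array(["A", "B", "C"])
n_weeks = 4
# Assumed layout: product_week_idx[i] is the product index of the i-th column.
product_week_idx = np.repeat(np.arange(len(products)), n_weeks)

product_name = "B"
# np.where returns a tuple of arrays; [0][0] extracts the scalar position.
selected_product_index = np.where(products == product_name)[0][0]
# [0] unpacks the tuple into a flat integer array of column positions.
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]

print(selected_product_index)       # 1
print(selected_product_week_index)  # [4 5 6 7]
```

Without the trailing [0], the second lookup would hand back a one-element tuple rather than an array, which is why a .shape check on it would fail.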
The key is to convert the relevant subset of the trace to an xarray.Dataset before passing it to az.plot_forest():
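trace_data = az.from_pymc3(trace=trace)
# posterior variables are laid out as (chain, draw, parameter),
# so the product/week index goes on the third axis
plot_data = (
    trace_data
    .posterior["my_trace_variable"][:, :, selected_product_week_index]
    .to_dataset()
)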
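As a rough check on what that slice does, here is a NumPy-only sketch that uses a random array as a stand-in for the posterior variable. The sizes and the hard-coded index are made up for illustration; the only point is that indexing the third axis keeps every chain and draw and picks out just the selected columns, in the order given.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_params = 3, 5, 12  # toy sizes, not the real trace
# Stand-in for the (chain, draw, parameter) array behind the posterior variable.
posterior = rng.normal(size=(n_chains, n_draws, n_params))

# Hypothetical index of the columns belonging to one product.
selected_product_week_index = np.array([4, 5, 6, 7])

# Keep all chains and draws, select only the chosen parameter columns.
subset = posterior[:, :, selected_product_week_index]
print(subset.shape)  # (3, 5, 4)
```

With the real trace the same slice would leave 52 columns, one per week of the selected product.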
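# note: newer ArviZ releases rename credible_interval to hdi_prob
# and return only the array of axes (no figure)
_, axes = az.plot_forest(plot_data, credible_interval=0.95, combined=True)
# plot_forest returns an array of axes; the forest itself is on the first one
# reverse the week names for the plot
axes[0].set_yticklabels(weeks[::-1])
axes[0].set_title(product_name)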