I think I’m answering my own question here, so this is to share and get feedback (if any).
I’ve been wanting to use the arviz.plot_forest() function to visualize a trace. However, the trace contains a number of variables, and for each variable, hundreds if not thousands of parameters.
For example, I have a trace that contains posteriors for a parameter for 156 products by 52 weeks. I wanted to create a forestplot for only one of those products.
My data is arranged row-wise, so I need to index into the correct product.
Here’s what I’ve come up with:
import numpy as np
import arviz as az
weeks = [...array of 52 weeks...]
products = [....array of 156 products...]
product_week_idx = [...index of 8,112 product/week combinations]
trace = [...PyMC3 trace with 6000 samples from 3 chains, 8112 parameters per variable]
# e.g.
# trace["my_trace_variable"].shape
# > (6000, 8112)
# find index of selected product in products array
product_name = "The Product I Want to Plot"
selected_product_index = np.where(products == product_name)[0][0]
# find index of selected product/week combos in products x weeks array
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]
# selected_product_week_index.shape
# > (52,)
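To make the two-step lookup above concrete, here is a minimal sketch on toy data. The sizes (3 products, 4 weeks) and the way product_week_idx is built with np.repeat are assumptions for illustration only, not my real data. Note that products has to be a NumPy array: with a plain Python list, products == product_name is not an elementwise comparison and the lookup fails.

```python
import numpy as np

# Toy stand-ins (hypothetical values): 3 products observed over 4 weeks.
products = np.array(["A", "B", "C"])
weeks = np.array(["w1", "w2", "w3", "w4"])

# Row-wise layout: one entry per product/week row, holding that row's product index.
product_week_idx = np.repeat(np.arange(len(products)), len(weeks))

# Step 1: position of the selected product in the products array.
selected_product_index = np.where(products == "B")[0][0]

# Step 2: positions of all rows that belong to that product.
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]

print(selected_product_index)       # 1
print(selected_product_week_index)  # [4 5 6 7]
```

The resulting index array has one entry per week, which is what gets used to slice the parameter axis of the trace.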
The key is to convert the relevant subset of the trace to an xarray.Dataset before passing it to az.plot_forest():
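# convert the PyMC3 trace to an ArviZ InferenceData object
trace_data = az.from_pymc3(trace=trace)
# posterior dims are (chain, draw, parameter); keep only the selected product's weeks
plot_data = (trace_data
    .posterior["my_trace_variable"][:, :, selected_product_week_index]
    .to_dataset()
)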
# credible_interval is the argument name in older ArviZ releases; newer releases call it hdi_prob
_, axes = az.plot_forest(plot_data,
                         credible_interval=0.95,
                         combined=True)
# the first week is drawn at the top, so reverse the week names to match the tick order
axes[0].set_yticklabels(weeks[::-1])
axes[0].set_title(product_name)
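For anyone wondering why the slice has three axes when trace["my_trace_variable"] only has two: the converted posterior keeps chains and draws as separate leading dimensions, with the parameters on the last axis. Below is a rough NumPy-only sketch of that slicing. The array is random data standing in for the posterior values, and the sizes (3 chains, 5 draws, 12 parameters) are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the posterior values: (chain, draw, parameter).
posterior_values = rng.normal(size=(3, 5, 12))

# Index of the rows belonging to one product (assumed, as in a toy 3x4 layout).
selected_product_week_index = np.array([4, 5, 6, 7])

# Keep every chain and draw, but only the selected parameters.
subset = posterior_values[:, :, selected_product_week_index]
print(subset.shape)  # (3, 5, 4)

# Stacking the chains gives a 2-D (samples, parameters) layout,
# like the shape shown for trace["my_trace_variable"] above.
flat = subset.reshape(-1, subset.shape[-1])
print(flat.shape)  # (15, 4)
```

With the real data the same slice would presumably come out as (3, 2000, 52), assuming the 6000 samples are split evenly across the 3 chains.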