How to forestplot in arviz with only a subset of trace variables and samples

I think I’m answering my own question here, so this is to share and get feedback (if any).

I’ve been wanting to use the arviz.plot_forest() function to visualize a trace. However, the trace contains a number of variables, and for each variable, hundreds if not thousands of parameters.

For example, I have a trace that contains posteriors for a parameter for 156 products by 52 weeks. I wanted to create a forestplot for only one of those products.
My data is arranged row-wise, so I need to index into the correct product.

Here’s what I’ve come up with:

```python
import numpy as np
import arviz as az

weeks = ...             # array of 52 weeks
products = ...          # np.array of 156 product names
product_week_idx = ...  # np.array with the product index of each of the 8,112 product/week rows
trace = ...             # PyMC3 trace with 6000 samples from 3 chains, 8112 parameters per variable

# e.g.
# trace["my_trace_variable"].shape
# > (6000, 8112)

# find the index of the selected product in the products array
product_name = "The Product I Want to Plot"
selected_product_index = np.where(products == product_name)[0][0]

# find the indices of the selected product's product/week combinations
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]
# selected_product_week_index.shape
# > (52,)
```
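To see what the two `np.where` calls return, here is a toy version with hypothetical sizes (3 products × 4 weeks instead of 156 × 52) and made-up product names. It assumes `products` and `product_week_idx` are NumPy arrays: comparing a plain Python list with `==` returns a single `False` instead of an element-wise mask.

```python
import numpy as np

# Hypothetical toy data: 3 products x 4 weeks = 12 product/week rows
products = np.array(["apple", "banana", "cherry"])
n_weeks = 4
# Product index of each row, arranged row-wise: product 0 for 4 weeks, then product 1, ...
product_week_idx = np.repeat(np.arange(len(products)), n_weeks)

product_name = "banana"
# Position of the product in the products array
selected_product_index = np.where(products == product_name)[0][0]
# Columns of the flat (draws, rows) posterior that belong to this product
selected_product_week_index = np.where(product_week_idx == selected_product_index)[0]

print(selected_product_index)        # 1
print(selected_product_week_index)   # [4 5 6 7]
```

Because the mask is computed row by row, the same two lines also work when the rows are not sorted by product.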

The key is to turn the relevant subset of the posterior into an xarray.Dataset before passing it to `az.plot_forest()`:

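```python
# convert the PyMC3 trace to an ArviZ InferenceData object
product_week_trace_data = az.from_pymc3(trace=trace)

# keep all chains and draws, but only the selected product's 52 parameters,
# then turn the resulting DataArray into an xarray.Dataset
plot_data = (product_week_trace_data
             .posterior["my_trace_variable"][:, :, selected_product_week_index]
             .to_dataset())

# recent ArviZ versions return an array of matplotlib Axes
axes = az.plot_forest(plot_data)
# reverse the week names for the plot
# ...
```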
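Since the real PyMC3 trace isn't reproducible here, the sketch below swaps in `az.from_dict` with random draws (ArviZ 0.x API; hypothetical sizes of 3 chains × 100 draws and 3 products × 4 weeks) to show the same subset-then-`to_dataset()` pattern end to end. Without explicit dims, ArviZ names the parameter dimension `my_trace_variable_dim_0`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
n_chains, n_draws, n_products, n_weeks = 3, 100, 3, 4

# Hypothetical posterior with one flat parameter per product/week row,
# shaped (chain, draw, parameter) as ArviZ expects
draws = rng.normal(size=(n_chains, n_draws, n_products * n_weeks))
idata = az.from_dict(posterior={"my_trace_variable": draws})

# Columns of the second product (rows 4..7 in a row-wise layout)
selected_product_week_index = np.arange(4, 8)

# Positional indexing over (chain, draw, parameter), then DataArray -> Dataset
plot_data = (idata
             .posterior["my_trace_variable"][:, :, selected_product_week_index]
             .to_dataset())

# combined=True merges the chains into one interval per parameter
axes = az.plot_forest(plot_data, combined=True)
```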

If you use the dims and coords kwargs in az.from_pymc3 you could get rid of all that index juggling. You can then plot using something like

plot_data = (product_week_trace_data
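A sketch of the dims/coords route, assuming the ArviZ 0.x API. `az.from_dict` stands in for `az.from_pymc3`, which accepts the same `coords` and `dims` keywords; the dimension name `product_week`, the product and week names, and the `"product/week"` label format are all hypothetical. With labelled coordinates, the selection can be handed to `plot_forest` through its `coords` argument instead of slicing by integer position.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
products = ["apple", "banana", "cherry"]  # hypothetical product names
weeks = ["w1", "w2", "w3", "w4"]          # hypothetical week labels

# One label per flat product/week row, in the same row-wise order as the parameter
labels = [f"{p}/{w}" for p in products for w in weeks]
draws = rng.normal(size=(3, 100, len(labels)))  # (chain, draw, product_week)

idata = az.from_dict(
    posterior={"my_trace_variable": draws},
    coords={"product_week": labels},
    dims={"my_trace_variable": ["product_week"]},
)

# Pick the product's rows by label instead of by integer position
selected = [label for label in labels if label.startswith("banana/")]
axes = az.plot_forest(
    idata,
    var_names=["my_trace_variable"],
    coords={"product_week": selected},  # passed on to Dataset.sel
    combined=True,
)
```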

The .to_dataset is a nice method I hadn’t used.

I was hoping you could use named coordinates and dimensions in a nice way here, but it is difficult: I don't think xarray supports compound indices, and that seems to be what you want with product_week_idx.

Hm, seems I didn’t read the question properly :slight_smile:
How about something like this?

```python
import xarray as xr

ds = xr.Dataset({
    'item': ('item', ['item1', 'item2', 'item3']),
    'week': ['week1', 'week2', 'week3'],
    'product': ('product', ['bar', 'foo']),

    'item_product': ('item', ['bar', 'foo', 'bar']),
    'item_week': ('item', ['week1', 'week2', 'week1']),
    'item_data': ('item', [1, 2, 3]),
})
print(ds)
```
```
<xarray.Dataset>
Dimensions:       (item: 3, product: 2, week: 3)
Coordinates:
  * item          (item) <U5 'item1' 'item2' 'item3'
  * week          (week) <U5 'week1' 'week2' 'week3'
  * product       (product) <U3 'bar' 'foo'
Data variables:
    item_product  (item) <U3 'bar' 'foo' 'bar'
    item_week     (item) <U5 'week1' 'week2' 'week1'
    item_data     (item) int64 1 2 3
```

```
<xarray.Dataset>
Dimensions:       (item: 2, product: 2, week: 3)
Coordinates:
  * item          (item) <U5 'item1' 'item3'
  * week          (week) <U5 'week1' 'week2' 'week3'
  * product       (product) <U3 'bar' 'foo'
Data variables:
    item_product  (item) <U3 'bar' 'bar'
    item_week     (item) <U5 'week1' 'week1'
    item_data     (item) int64 1 3
```
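The call that produced the second output isn't shown; one way to get it (an assumption, not necessarily what was used) is boolean indexing along `item` with `isel`, keeping only the items whose `item_product` is `'bar'`. The dataset is rebuilt so the sketch runs on its own, with `week` written as an explicit `(dim, values)` tuple.

```python
import xarray as xr

ds = xr.Dataset({
    "item": ("item", ["item1", "item2", "item3"]),
    "week": ("week", ["week1", "week2", "week3"]),
    "product": ("product", ["bar", "foo"]),
    "item_product": ("item", ["bar", "foo", "bar"]),
    "item_week": ("item", ["week1", "week2", "week1"]),
    "item_data": ("item", [1, 2, 3]),
})

# Boolean mask along the item dimension; isel keeps only the True positions
mask = (ds["item_product"] == "bar").values
bar_items = ds.isel(item=mask)
print(bar_items)
```

Unlike `ds.where(ds.item_product == "bar", drop=True)`, which typically promotes the integer `item_data` to float, `isel` keeps the `int64` dtype seen in the output.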


You can also tell xarray that item is a MultiIndex built from item_product and item_week:

```python
ds.set_index(item=['item_product', 'item_week'])
```
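With the MultiIndex in place, a level can be selected by name. A minimal sketch on the same toy dataset: `set_index` returns a new Dataset rather than modifying `ds`, and MultiIndex behaviour has changed between xarray releases, so treat the details as version-dependent.

```python
import xarray as xr

ds = xr.Dataset({
    "item": ("item", ["item1", "item2", "item3"]),
    "week": ("week", ["week1", "week2", "week3"]),
    "product": ("product", ["bar", "foo"]),
    "item_product": ("item", ["bar", "foo", "bar"]),
    "item_week": ("item", ["week1", "week2", "week1"]),
    "item_data": ("item", [1, 2, 3]),
})

# Combine the two per-item variables into a MultiIndex on the item dimension
indexed = ds.set_index(item=["item_product", "item_week"])

# Select on one level of the MultiIndex; the other level remains as the index
bar = indexed.sel(item_product="bar")
print(bar["item_data"].values)
```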

I’m not too convinced by the current MultiIndex support in xarray though, so maybe it is better to do this manually for now.

Nice! I’ll give that a shot. Thx!