How do I pull a subset of an ArviZ array within the inference data?


I ran a prior predictive check with a model that is indexed on 800 different groups. However, I only want to pull 10 of those groups to show a subset. Currently, running the following code plots all of them:

```python
import matplotlib.pyplot as plt
import pymc as pm

with pooled_model:
    # prior_checks = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains = 4, target_accept=0.9)
    prior_checks = pm.sample_prior_predictive()

_, ax = plt.subplots()
prior_checks.prior.plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
plt.title("Prior Check")
ax.set_ylabel("Mean log Eaches Sold");
```

I looked in the ArviZ docs and only found how to pull a subset of samples, not a subset of items. Has anyone done this before?

If you set up coords in your model, the ArviZ plotting functions have a coords argument you can pass a dictionary to. So if your model’s coords dictionary has a key-value pair groups: [1, 2, 3, ... , 800], you can pass coords={'groups': [5, 8, 20]} to plot data only for groups 5, 8, and 20.
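A minimal sketch of the coords argument, assuming a recent ArviZ and using made-up prior draws built with az.from_dict instead of a real PyMC model; the variable name a and the dimension name groups are illustrative only:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import arviz as az
import numpy as np

rng = np.random.default_rng(0)
groups = list(range(1, 801))  # hypothetical labels for 800 groups

# Hypothetical prior draws: 1 chain x 500 draws x 800 groups
idata = az.from_dict(
    prior={"a": rng.normal(0.0, 10.0, size=(1, 500, len(groups)))},
    coords={"groups": groups},
    dims={"a": ["groups"]},
)

# coords restricts the plot to the listed group labels (one panel per group)
axes = az.plot_posterior(
    idata, group="prior", var_names=["a"], coords={"groups": [5, 8, 20]}
)
```

Only the three selected groups are drawn; the other 797 are never plotted.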

There’s an example of this in the docs here if you scroll down a bit.


I’m not sure what ‘constant_data’ is in the example…

```python
obs_county = data.posterior["County"][data.constant_data["county_idx"]]
data = data.assign_coords(obs_id=obs_county, groups="observed_vars")
az.plot_ppc(data, coords={'obs_id': ['ANOKA', 'BELTRAMI']}, flatten=[])
```
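As I understand it, constant_data is the InferenceData group for model inputs that are not random variables, such as integer index arrays registered with pm.Data. The first line of the example uses it to turn one integer index per observation into one county label per observation. The toy sketch below reproduces that step with plain xarray; the county names, indices and y values are made up, and the last step selects observations with a boolean mask and .isel rather than through plot_ppc:

```python
import numpy as np
import xarray as xr

# Hypothetical stand-ins for data.posterior["County"] (the coordinate labels)
# and data.constant_data["county_idx"] (one integer index per observation)
county = xr.DataArray(["ANOKA", "BELTRAMI", "CASS"], dims="County")
county_idx = xr.DataArray(np.array([0, 0, 1, 2, 1]), dims="obs_id")

# Indexing with a DataArray is vectorized: the result takes the indexer's
# dimension (obs_id) and holds the county label of every observation
obs_county = county[county_idx]
print(obs_county.values)  # ['ANOKA' 'ANOKA' 'BELTRAMI' 'CASS' 'BELTRAMI']

# Hypothetical observed values, relabelled so obs_id carries the county names
y = xr.DataArray(np.array([1.1, 0.9, 2.0, 3.5, 2.2]), dims="obs_id")
y = y.assign_coords(obs_id=obs_county.values)

# Keep every observation whose label is one of the chosen counties
mask = y["obs_id"].isin(["ANOKA", "BELTRAMI"]).values
subset = y.isel(obs_id=np.flatnonzero(mask))
print(subset.values)
```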

Here is my inference data:

So I tried the following, but it didn’t work:

```python
obs_item = prior_checks.prior['ITEM_NUMBER'][prior_checks.constant_data['item_idx']]
```

For reference, here is my model.

```python
import pymc as pm

coords = {"ITEM_NUMBER": list(items_loc)}
with pm.Model(coords=coords, rng_seeder=RANDOM_SEED) as pooled_model:
    store_idx = pm.Data('item_idx', items, dims="obs_id", mutable=True)
    a = pm.Normal("a", 0.0, sigma=10.0, dims="ITEM_NUMBER")

    theta = a[store_idx]
    sigma = pm.HalfCauchy("error", 0.5)

    y = pm.Normal("y", theta, sigma=sigma, observed=training_data['eaches'], dims="obs_id")
```

It looks like the first 2 lines in the example just set up the coords, which you have already done. So skip to the 3rd line and use the .plot(..., coords={'ITEM_NUMBER': ['100179', '100186', ...]}) part.

Getting this error:

```
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_2168/ in <module>
      1 _, ax = plt.subplots()
----> 2 prior_checks.prior.plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax, coords = {'ITEM_NUMBER':['100179', '100186', 'WHU9922']}, flatten=[])
      3 plt.xticks(rotation=45)
      5 ax.set_ylabel("Mean log Sales(eaches)");

/opt/conda/lib/python3.7/site-packages/xarray/plot/ in plotmethod(_PlotMethods_obj, x, y, u, v, hue, hue_style, col, row, ax, figsize, col_wrap, sharex, sharey, aspect, size, subplot_kws, add_guide, cbar_kwargs, cbar_ax, vmin, vmax, norm, infer_intervals, center, levels, robust, colors, extend, cmap, **kwargs)
    470         for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]:
    471             del allargs[arg]
--> 472         return newplotfunc(**allargs)
    474     # Add to class _PlotMethods

/opt/conda/lib/python3.7/site-packages/xarray/plot/ in newplotfunc(ds, x, y, u, v, hue, hue_style, col, row, ax, figsize, size, col_wrap, sharex, sharey, aspect, subplot_kws, add_guide, cbar_kwargs, cbar_ax, vmin, vmax, norm, infer_intervals, center, levels, robust, colors, extend, cmap, **kwargs)
    385             hue_style=hue_style,
    386             cmap_params=cmap_params_subset,
--> 387             **kwargs,
    388         )

/opt/conda/lib/python3.7/site-packages/xarray/plot/ in scatter(ds, x, y, ax, **kwargs)
    534         primitive = ax.scatter(
--> 535             data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs
    536         )

/opt/conda/lib/python3.7/site-packages/matplotlib/ in inner(ax, data, *args, **kwargs)
   1410     def inner(ax, *args, data=None, **kwargs):
   1411         if data is None:
-> 1412             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1414         bound = new_sig.bind(ax, *args, **kwargs)

/opt/conda/lib/python3.7/site-packages/matplotlib/axes/ in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, edgecolors, plotnonfinite, **kwargs)
   4466                 )
   4467         collection.set_transform(mtransforms.IdentityTransform())
-> 4468         collection.update(kwargs)
   4470         if colors is None:

/opt/conda/lib/python3.7/site-packages/matplotlib/ in update(self, props)
   1062                     func = getattr(self, f"set_{k}", None)
   1063                     if not callable(func):
-> 1064                         raise AttributeError(f"{type(self).__name__!r} object "
   1065                                              f"has no property {k!r}")
   1066                     ret.append(func(v))

AttributeError: 'PathCollection' object has no property 'coords'
```

I used coords in my model. I even see them correctly in my summary.


I assume you are using v4 from what you are saying. In v4 all PyMC sampling functions return InferenceData by default.

InferenceData is an ArviZ-specific object designed to store the results of Bayesian models; however, the data is not stored in ArviZ-specific arrays. Each group of an InferenceData is an xarray.Dataset, which is a container for multiple variables whose dimensions (and optionally coordinate values) are labeled.
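A quick way to check this, assuming ArviZ and xarray are installed and using hypothetical prior draws built with az.from_dict (the item labels are placeholders, not your real items):

```python
import arviz as az
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
items = ["100179", "100186", "WHU9922"]  # hypothetical item labels

# Hypothetical prior draws for a variable "a" indexed by ITEM_NUMBER
idata = az.from_dict(
    prior={"a": rng.normal(0.0, 10.0, size=(1, 100, len(items)))},
    coords={"ITEM_NUMBER": items},
    dims={"a": ["ITEM_NUMBER"]},
)

prior = idata.prior  # each group is a plain xarray.Dataset
print(isinstance(prior, xr.Dataset))  # True
print(prior["a"].dims)  # ('chain', 'draw', 'ITEM_NUMBER')
```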

All dimensions in a Dataset are equal: you can index along any of them using the .sel or .isel methods (for label-based or positional indexing, respectively). The chain and draw dimensions are not special.
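A minimal sketch of both kinds of indexing on a prior group, with the same kind of hypothetical setup (made-up draws and placeholder item labels built with az.from_dict):

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
items = ["100179", "100186", "WHU9922"]  # hypothetical item labels
idata = az.from_dict(
    prior={"a": rng.normal(0.0, 10.0, size=(1, 100, len(items)))},
    coords={"ITEM_NUMBER": items},
    dims={"a": ["ITEM_NUMBER"]},
)

# Label-based: pick items by their coordinate values
by_label = idata.prior.sel(ITEM_NUMBER=["100179", "WHU9922"])

# Positional: first 2 items and first 50 draws; draw is indexed like any other dim
by_position = idata.prior.isel(ITEM_NUMBER=slice(0, 2), draw=slice(0, 50))

print(by_label["a"].shape)  # (1, 100, 2)
print(by_position["a"].shape)  # (1, 50, 2)
```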

Therefore, assuming you have a list with the coordinate values you want to plot, instead of plotting everything as in your original code (where I assume a .prior["variable_name"] is missing), you need to select the subset of interest before plotting, for example:

```python
# assuming items_of_interest variable has the coord values you want to plot
prior_checks.prior["var_name"].sel(ITEM_NUMBER=items_of_interest).plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
# or to plot the first 20 items, use isel for positional indexing
prior_checks.prior["var_name"].isel(ITEM_NUMBER=slice(None, 20)).plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
```

ArviZ functions provide the coords argument to make it more convenient to indicate the subset to plot from a single function call, but internally they call the .sel method. Therefore, you can also get the same result by subsetting first with either .sel or .isel and then passing the result to an ArviZ function without the coords argument.
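To illustrate the equivalence, the sketch below computes an HDI both ways with az.hdi, which also accepts group and coords arguments. The draws and item labels are hypothetical, and az.hdi is used instead of a plot only because its output is easy to compare:

```python
import arviz as az
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
items = ["100179", "100186", "WHU9922"]  # hypothetical item labels
idata = az.from_dict(
    prior={"a": rng.normal(0.0, 10.0, size=(2, 200, len(items)))},
    coords={"ITEM_NUMBER": items},
    dims={"a": ["ITEM_NUMBER"]},
)
wanted = ["100179", "WHU9922"]

# 1) Let ArviZ subset through the coords argument
hdi_coords = az.hdi(idata, group="prior", coords={"ITEM_NUMBER": wanted})

# 2) Subset first with .sel, then pass the Dataset to ArviZ without coords
hdi_sel = az.hdi(idata.prior.sel(ITEM_NUMBER=wanted))

# Both routes give the same numbers for the same items
xr.testing.assert_allclose(hdi_coords, hdi_sel)
```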


I’m using ArviZ version 0.13.0.dev0. Maybe that is why it doesn’t recognize the coords argument or your suggested use of .sel?

When I run:

```python
prior_checks.prior["ITEM_NUMBER"].sel(ITEM_NUMBER='100179').plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
```

I get:

```
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_2168/ in <module>
      1 _, ax = plt.subplots()
----> 2 prior_checks.prior["ITEM_NUMBER"].sel(ITEM_NUMBER='100179').plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
      5 plt.xticks(rotation=45)

AttributeError: '_PlotMethods' object has no attribute 'scatter'
```

The line of code you ran is not using ArviZ; it is pure xarray, and therefore the ArviZ version is irrelevant. The groups in InferenceData are xarray.Dataset objects, so once you select a group you can use any and all xarray features directly.

As for coords and sel: ArviZ functions take an argument called coords for subsetting, whereas xarray uses the .sel and .isel methods.

As for the error, it looks like scatter is a plotting method of Dataset but not of DataArray (at least in the xarray version you are running), so my assumption was wrong: your original plot had the .prior but not the ["var_name"], otherwise it would not have worked. Use:

```python
prior_checks.prior.sel(ITEM_NUMBER='100179').plot.scatter(x="ITEM_NUMBER", y="a", color="k", alpha=0.2, ax=ax)
```
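For the original goal of showing only 10 of the 800 items, here is a sketch under a few assumptions: the prior draws and item labels are made up with az.from_dict, and it calls Matplotlib’s ax.scatter directly instead of xarray’s Dataset.plot.scatter, so it does not depend on which plotting methods your xarray version offers:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line in a notebook

import arviz as az
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
items = [f"item_{i:03d}" for i in range(800)]  # hypothetical item labels
idata = az.from_dict(
    prior={"a": rng.normal(0.0, 10.0, size=(1, 500, len(items)))},
    coords={"ITEM_NUMBER": items},
    dims={"a": ["ITEM_NUMBER"]},
)

# Keep only the first 10 items (positional); use .sel with a list of labels
# instead to pick specific items
a = idata.prior["a"].isel(ITEM_NUMBER=slice(0, 10))

# Repeat each item label for every chain/draw so x and y line up element-wise
x = np.broadcast_to(a["ITEM_NUMBER"].values, a.shape)

_, ax = plt.subplots()
ax.scatter(x.ravel(), a.values.ravel(), color="k", alpha=0.2)
ax.set_title("Prior Check")
ax.set_ylabel("Mean log Eaches Sold")
plt.xticks(rotation=45)
```

Broadcasting the labels works because ITEM_NUMBER is the last dimension of a, so the label array lines up with the trailing axis of the draws.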

Thank you.

I’m following this post but with my own data: A Primer on Bayesian Methods for Multilevel Modeling — PyMC documentation

The post works well for walking my audience through the process. However, I would like to do the same thing using ArviZ the correct way. I don’t see in the documentation how to duplicate these plots exactly, but I will keep trying to improve the ArviZ way.

Thanks for the help.

The multilevel modeling post uses ArviZ the “correct way”. ArviZ is built on top of xarray, relies on it, and tries to avoid duplicating functionality that is already available in xarray. Using xarray directly is encouraged and recommended, and that notebook does exactly that: it uses xarray plotting, xarray groupby capabilities, fancy indexing, and automatic broadcasting and alignment of arrays…
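A small illustration of broadcasting and alignment with plain xarray, using hypothetical arrays shaped like the a and error variables of the model above (all values and labels are made up):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
items = ["100179", "100186", "WHU9922"]  # hypothetical item labels

# Hypothetical draws: one "a" per item, one "error" per chain/draw
a = xr.DataArray(
    rng.normal(0.0, 10.0, size=(2, 100, 3)),
    dims=("chain", "draw", "ITEM_NUMBER"),
    coords={"ITEM_NUMBER": items},
)
error = xr.DataArray(np.abs(rng.normal(0.0, 0.5, size=(2, 100))), dims=("chain", "draw"))

# Broadcasting by dimension name: error has no ITEM_NUMBER dim,
# so each draw's error is reused for every item
upper = a + 2 * error

# Alignment by label: the weights are listed in a different order, and
# xarray matches them to a by ITEM_NUMBER value, not by position
weights = xr.DataArray(
    [1.0, 2.0, 3.0],
    dims="ITEM_NUMBER",
    coords={"ITEM_NUMBER": ["WHU9922", "100179", "100186"]},
)
weighted = a * weights
```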

I know xarray can be hard to understand, probably ArviZ too, and it is important to know where the docs are lacking. It is probably worth adding a note to the Working with InferenceData page explaining that every dimension can be subsetted with .sel or .isel, not only chain or draw as in the example there.

The only suggestion I’m having a hard time with is the constant_data used throughout this notebook.


```python
hdi_helper = lambda ds: az.hdi(ds, input_core_dims=[["chain", "draw", "obs_id"]])
hdi_ppc = (
```

I found this in the docs…

InferenceData schema specification — ArviZ dev documentation

And while I see it in the inference objects shown in the ArviZ docs, I don’t see it in my inference object.
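As far as I know, the ArviZ schema uses constant_data for model inputs that are not random variables, which is where index arrays like item_idx typically end up, and the notebook appears to group posterior predictive draws by such an index to get one HDI per group. Since what your PyMC call returns may differ, this sketch builds a small InferenceData by hand with az.from_dict; the data are made up, and hdi_helper mirrors the notebook’s helper:

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
item_idx = np.repeat([0, 1, 2], 4)  # hypothetical: 12 observations, 4 per item

# Hand-built InferenceData with a posterior_predictive and a constant_data group
idata = az.from_dict(
    posterior_predictive={"y": rng.normal(size=(2, 50, item_idx.size))},
    constant_data={"item_idx": item_idx},
    dims={"y": ["obs_id"], "item_idx": ["obs_id"]},
)


def hdi_helper(da):
    # One HDI over chain, draw and obs_id jointly; az.hdi returns a Dataset
    return az.hdi(da, input_core_dims=[["chain", "draw", "obs_id"]])


# Group the observations by the item index stored in constant_data,
# then compute the HDI of y within each item
hdi_ppc = (
    idata.posterior_predictive["y"]
    .groupby(idata.constant_data["item_idx"])
    .map(hdi_helper)["y"]
)
```

The result has one row per item index and a lower/higher pair per row, ready to plot or relabel with your item coordinates.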