Control the order of coordinates in arviz.InferenceData

I have an arviz.InferenceData:

In [27]: inf = arviz.load_arviz_data('non_centered_eight')

In [28]: inf.posterior.coords
Out[28]: 
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'

I want to rearrange the schools to be in custom order. The goal is to control the order in which they appear on the plots.

I tried simply sorting them in place,

inf.posterior.coords['school'] = inf.posterior.coords['school'].sortby(
    inf.posterior.coords['school']
)

but I think this simply rearranges the labels, giving the wrong labels for each estimate!

How can I change the order of the coordinates in my plots, while retaining the correct association of labels (school name) to school estimates?

In simple cases, I could just rearrange them before running the estimation. The problem is that eventually I want to sort the schools in order of estimated mean value, so it has to happen after estimation.

Hi, good question. I don’t have a definite answer, but here are some possible solutions:

You might want to try the sorting methods of xarray.Dataset (e.g. on inf.posterior), with a suitable sort key if possible.
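As a rough sketch of the sorting idea, the snippet below builds a toy stand-in for inf.posterior (the values and the three-school subset are made up for illustration) and sorts the whole Dataset with sortby. For contrast, the last lines overwrite the labels with a plain array, which is one way to end up with the label/estimate mismatch the question worries about; it is not necessarily exactly what the original snippet does.

```python
import numpy as np
import xarray as xr

# Toy stand-in for inf.posterior: one made-up value per school, no chain/draw dims.
posterior = xr.Dataset(
    {"theta": ("school", np.array([3.0, 1.0, 2.0]))},
    coords={"school": ["Hotchkiss", "Choate", "Deerfield"]},
)

# Sorting the Dataset by a coordinate moves data and labels together.
by_name = posterior.sortby("school")
print(by_name.school.values)  # schools in alphabetical order
print(by_name.theta.values)   # values travel with their labels

# Overwriting the labels with a plain array only renames positions;
# the data stays where it was, so labels and estimates no longer match.
relabeled = posterior.assign_coords(school=np.sort(posterior.school.values))
print(float(relabeled.theta.sel(school="Choate")))  # no longer Choate's value
```

Because sortby is called on the Dataset rather than on a single coordinate, every variable that has the school dimension is reordered consistently.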

Another way to solve this might be slicing with .sel/.isel, applied either to inf or to one of its groups. I’m not sure whether this will work.
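A minimal sketch of the .sel/.isel idea, using synthetic draws in place of the real posterior (the array shapes, the three schools, and the names custom_order and positions are assumptions for illustration): .sel takes a list of labels and returns them in that order, while .isel takes integer positions, which can come from np.argsort.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Hotchkiss"]

# Toy stand-in for inf.posterior with dims (chain, draw, school).
posterior = xr.Dataset(
    {"theta": (("chain", "draw", "school"), rng.normal(size=(2, 50, 3)))},
    coords={"chain": [0, 1], "draw": np.arange(50), "school": schools},
)

# .sel takes labels: the result follows the order of the list passed in.
custom_order = ["Hotchkiss", "Choate", "Deerfield"]
by_label = posterior.sel(school=custom_order)

# .isel takes integer positions, here the positions that sort the posterior means.
means = posterior.theta.mean(dim=("chain", "draw"))
positions = np.argsort(means.values)
by_mean = posterior.isel(school=positions)

print(by_label.school.values)
print(by_mean.theta.mean(dim=("chain", "draw")).values)  # non-decreasing
```

In both cases the draws stay attached to their school labels, since the selection is applied to the whole Dataset.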

I’d suggest trying to ask the xarray devs on StackOverflow. Maybe you can find an answer there. If you do, would you respond to this and put a link to the StackOverflow question so that other Arviz users can find an answer here?

Is there any chance that reorder_levels will do what you want?

https://xarray-test.readthedocs.io/en/latest/generated/xarray.DataArray.reorder_levels.html

I’m not sure that’s right, because the documentation there doesn’t have an example, and it’s full of terms and concepts (like the relationship between DataArrays and multi-indexes) that I don’t understand. But it might be worth experimenting with it. Good luck!

Oh, yes, another issue is that the coordinates are a property of the Dataset, but this only modifies a component Data array…

Perhaps you could just specify the order that you wish in the plotting functions themselves. I believe that should achieve what you wish.

import arviz as az
import pymc3 as pm

with pm.Model() as model:
    a = pm.Normal('a', 0., 1.) # group mean
    sigma = pm.Exponential('sigma', 1.) # determines amount of shrinkage

    a_cluster = pm.Normal('a_cluster', a, sigma, shape=2)
    p = pm.math.invlogit(a_cluster[[0, 1]])

    pm.Binomial('obs', p=p, n=[6, 125], observed=[1, 110])
    
    trace = pm.sample(target_accept=0.8)
    ppc = pm.sample_posterior_predictive(trace)
    idata = az.from_pymc3(trace, posterior_predictive=ppc, dims={'a_cluster':['ac_custom']}, coords={'ac_custom':['ac_0', 'ac_1']})
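The model above stops before the plotting call, so here is a hedged sketch of what passing the order to a plotting function might look like. To keep it quick to run, it fills an InferenceData with synthetic draws through az.from_dict (as found in ArviZ 0.x releases) instead of sampling the model; the draws and the name custom_order are assumptions, and the order is passed through the coords argument of az.plot_forest.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere

import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Synthetic draws with shape (chain, draw, ac_custom); not output of the model above.
idata = az.from_dict(
    posterior={"a_cluster": rng.normal(size=(2, 50, 2))},
    dims={"a_cluster": ["ac_custom"]},
    coords={"ac_custom": ["ac_0", "ac_1"]},
)

# Hypothetical custom order: list the coordinate values in the order wanted.
custom_order = ["ac_1", "ac_0"]
axes = az.plot_forest(
    idata,
    var_names=["a_cluster"],
    coords={"ac_custom": custom_order},
)
```

The same coords dictionary can be passed to other plotting functions that accept a coords argument.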

You can do this by manipulating the idata.posterior object, which is an xarray Dataset. If you only care about changing the order in which the plots appear, then you can index the data based on school using the Dataset.sel method.

Here’s an example where I select schools in the opposite order (sorry about the missing cells, I’m sure no one wants to see my typos!):

In [1]: import arviz as az

In [2]: inf = az.load_arviz_data("non_centered_eight")

In [3]: inf.posterior
Out[3]:
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 ...
    theta_t  (chain, draw, school) float64 ...
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
Attributes:
    created_at:                 2019-06-21T17:36:37.382566
    inference_library:          pymc3
    inference_library_version:  3.7

In [4]: _ = az.plot_trace(inf.posterior, var_names=["theta"])
# will plot image 1 below

In [6]: order = inf.posterior.school.values[::-1]

In [7]: order
Out[7]:
array(['Mt. Hermon', "St. Paul's", 'Lawrenceville', 'Hotchkiss',
       'Phillips Exeter', 'Phillips Andover', 'Deerfield', 'Choate'],
      dtype=object)

In [9]: _ = az.plot_trace(inf.posterior.sel(school=order), var_names=["theta"])
# will plot image 2 below

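Cell [4] plots this:

[image 1: trace plot of theta, schools in the original order]

Cell [9] plots this:

[image 2: trace plot of theta, schools in reversed order]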

In general, arviz integrates very nicely with xarray. Both are, in my experience, very well-designed and quite flexible, so hats off to the developers of both. If you want to plot subsets of your posterior, rename the labels, reorder things, etc., you can usually do it by manipulating the underlying xarray Dataset and passing that straight to arviz.

With that said, this was challenging to do at first—mostly because learning the language of xarray Datasets and the new abstractions they use took some trial and error. I think this might be a useful example to add to the arviz docs to highlight how you can do some operations in xarray to make your plotting life easier.

Thanks for all the great answers!

plot_forest(inferencedata.posterior.sortby(['school'])) seems to be another option.

The question has already been answered, showing also that there are different options available. I feel that this is quite relevant and will probably be asked again in the future, so I thought I would summarize the different options above and classify them depending on what they achieve:

  • Sort using coordinate values and/or variable names
    • Use the var_names and coords arguments of the plotting functions; see @nkaimcaudle’s answer. As far as I know, this is the only way to reorder data_variables.
    • Use the sel method, either on the xarray Dataset (idata.posterior.sel()) or on the InferenceData object itself (idata.sel()), as described in @tushar_chandra’s answer.
    • Use the sortby() method (only present for xarray objects), as pointed out by @mina. This probably provides the most flexibility. See also the xarray docs.
  • Sort using data variable values (for example sorting from lowest to highest posterior mean). Detailed below.

Plotting follows the order of the xarray object used. The initial order is the following:

idata = az.load_arviz_data("non_centered_eight")
az.plot_posterior(idata, var_names="theta")

We can use sortby to sort in a custom order, not only alphabetically by coordinate values. It automatically aligns and broadcasts its inputs to sort along the desired dimension. Below is one example that sorts by theta posterior means:

theta_post_means = idata.posterior.theta.mean(dim=("chain", "draw"))
# note the idata.posterior.theta, we need to specify the variable
idata.posterior = idata.posterior.sortby(theta_post_means)
az.plot_posterior(idata, var_names="theta")

Note 1: Any idata.posterior = <xarray methods> followed by az.plot_xyz(idata, ...) could also be written as az.plot_xyz(<xarray methods>, ...), and vice versa. One or the other may be more appropriate depending on the situation. To create many different plots with the same order, the first option is probably better, since sorting is performed once and idata is overwritten with the sorted results. To create several plots, each with a different order, the second option is probably better.

Note 2: All the approaches described here have some limitations. The most prominent one that comes to mind is that variables sharing a dimension will also be shown in the same order. In the example above, if we sort according to theta posterior means, theta_t is also reordered because the two share the school dimension.
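The shared-dimension behaviour in Note 2 can be seen on a toy Dataset (the numbers and the three-school subset are made up for illustration). The sketch also shows one possible way around it, under the assumption that only a single variable needs the custom order: pull that variable out as a DataArray and sort it on its own.

```python
import numpy as np
import xarray as xr

schools = ["Choate", "Deerfield", "Hotchkiss"]

# Toy posterior: theta and theta_t share the school dimension.
posterior = xr.Dataset(
    {
        "theta": ("school", np.array([3.0, 1.0, 2.0])),
        "theta_t": ("school", np.array([0.1, 0.3, 0.2])),
    },
    coords={"school": schools},
)

# Sorting the Dataset by theta reorders theta_t as well.
sorted_all = posterior.sortby(posterior.theta)
print(sorted_all.school.values)
print(sorted_all.theta_t.values)

# Sorting theta alone returns a new DataArray; the Dataset keeps its original order.
theta_only = posterior.theta.sortby(posterior.theta)
print(theta_only.school.values)
print(posterior.school.values)
```

The sorted DataArray can then be passed to a plotting function directly, in the spirit of the second option in Note 1.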

I thought people following this thread might be interested in the new labeling capabilities in ArviZ. The feature is only available in the development version for now, and will be included in the next release. The "Label guide" page of the ArviZ dev documentation describes what can now be customized regarding labels in ArviZ plotting and in stats functions like az.summary. It also has guidance on how to sort labels, dimensions and coordinate values.

These xarray methods are fantastic, thanks for the clear explanation too. Just saved me several messy LOC!

Thanks for the new docs too!
