Combining dimensions in ArviZ plots

I have a distribution that has a large dimension to the observed variable, obs. This large dimension, "rows", corresponds to different settings of independent variables, so I often want to plot the distribution of obs combined along some subset of the rows. plot_dist wants to plot each row separately.

az.plot_dist doesn’t support a coords argument, and anyway, I don’t want to get a large number of individual plots: I want to pool them.

I’m embarrassed to admit that I have no idea how to flatten out a subset of rows. One possibility, it seems, would be to treat each coordinate of obs as a separate set of chains. But the xarray documentation does not obviously provide a solution to the problem of reshaping the dimensions in this way. I considered reshaping the underlying numpy ndarray, but then it seemed like I would have to go to a lot of trouble to reconstruct the Dataset and InferenceData around the reshaped numpy array.

Stacking doesn’t seem like the right thing, either.

I can probably just do what I want in pandas, but I’d love to know how to do this in ArviZ – it seems like the sort of thing that one would want to do pretty frequently.


This sounds like an interesting question but I’m having trouble wrapping my head around a concrete case. Could you post a code example or something else that would help recreate the problem?

Off the top of my head, I think you’d want to use the .sel method on the rows dimension of the trace to subset to the combinations of the independent variables you need. Then you would .stack the rows and chain dimensions together into a new dimension.

It might look like trace.sel(rows=...).stack({pooled_chains: ["rows", "chain"]}) or something similar. This would give you a pooled_chains dimension made up of the regular chain and the subsetted rows.

You’d probably want to rename the chain to something else, and pooled_chains to chain, too. That’d just be so arviz knew which dimension to treat as chains. I think .rename_dims would be enough, but you might need to call it twice.
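To make that concrete, here is a minimal sketch on a made-up posterior (the shapes, the variable name mu_obs and the selected rows are all assumptions for illustration, not taken from a real model). It renames chain first and then stacks straight into a dimension called chain, which saves the second rename:

```python
import numpy as np
import xarray as xr

# Toy posterior with assumed shapes: 2 chains, 50 draws, 6 rows.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"mu_obs": (("chain", "draw", "rows"), rng.normal(size=(2, 50, 6)))},
    coords={"chain": [0, 1], "draw": np.arange(50), "rows": np.arange(6)},
)

pooled = (
    posterior.sel(rows=[0, 2, 5])        # keep only the rows of interest
    .rename({"chain": "sub_chain"})      # free up the name "chain"
    .stack(chain=("rows", "sub_chain"))  # 3 rows x 2 chains -> 6 pooled "chains"
)

print(pooled["mu_obs"].dims)
print(pooled.sizes["chain"])
```

The pooled chain dimension is backed by a MultiIndex over (rows, sub_chain), so the original row and chain labels are still recoverable after pooling.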

I hope this is useful to you. If not, a code sample would go a long way in helping us to answer your question. Best of luck! :slight_smile:


From what I understand, the ideal solution would be to extend the flatten arguments of plot_ppc to other plots, so that you can modify the coordinates to encode the grouping/pooling information and then have ArviZ handle axis generation, grouping and plotting all at once. I say this would be ideal because, as with plot_ppc (see the 2nd and 3rd examples from the bottom in arviz.plot_ppc — ArviZ dev documentation), this approach would work out of the box, even with ragged groups.

Another not so ideal solution would be to reshape the row dimension into multiple dimensions and then groupby (flatten would also work here but if flatten is available I think the first option is the way to go). I think ArviZ functions may work with dataset groupbys.
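As a rough illustration of the groupby idea, here is a sketch that groups the rows by a per-row variable instead of reshaping first. The temperature array and the posterior below are made-up stand-ins, and whether a given ArviZ plotting function accepts the resulting groups is not something this sketch checks:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(3)

# Hypothetical per-row setting, standing in for a constant_data variable.
temperature = xr.DataArray([37, 30, 37, 30, 37], dims=["rows"], name="temperature")
posterior = xr.Dataset(
    {"mu_obs": (("chain", "draw", "rows"), rng.normal(size=(2, 10, 5)))}
)

# Iterating over the groupby yields (label, sub-dataset) pairs, one per
# distinct temperature; each sub-dataset keeps only the matching rows.
group_sizes = {
    int(label): group.sizes["rows"]
    for label, group in posterior.groupby(temperature)
}
print(group_sizes)
```

Each sub-dataset could then be pooled and plotted separately, which handles ragged groups without any NaN padding.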

Thank you. Sorry to have taken so long to get back to you.

The values I need to subset on are inputs – i.e., values of a data variable in the constant_data group, so as far as I can tell, I need to use where, rather than sel – I couldn’t figure out a satisfactory way to transform these into dimensions:

filtered = idata.posterior.where(np.logical_and(idata.constant_data['input'] == 3, idata.constant_data['temperature'] == 37))

(Maybe there’s a less ugly way to do the logical and?)
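For anyone following along, here is a small sketch of what that where call does on toy data (the datasets below are hypothetical stand-ins for idata.constant_data and idata.posterior). The point to notice is that where keeps the original shape and fills the non-matching rows with NaN rather than removing them, which matters for anything computed from the result afterwards:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)

# Hypothetical stand-ins for idata.constant_data and idata.posterior.
constant_data = xr.Dataset(
    {
        "input": ("rows", [3, 3, 1, 3]),
        "temperature": ("rows", [37, 30, 37, 37]),
    }
)
posterior = xr.Dataset(
    {"mu_obs": (("chain", "draw", "rows"), rng.normal(size=(2, 10, 4)))}
)

# One boolean per row: True where both conditions hold.
mask = np.logical_and(
    constant_data["input"] == 3, constant_data["temperature"] == 37
)
filtered = posterior.where(mask)

n_kept = int(mask.sum())  # number of rows matching both conditions
# Rows that failed the mask are NaN for every chain and draw.
all_nan_rows = filtered["mu_obs"].isnull().all(dim=["chain", "draw"])
print(filtered.sizes["rows"], n_kept, all_nan_rows.values)
```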

After some experimentation, I get to this:

filtered = (
    idata.posterior.where(np.logical_and(idata.constant_data['input'] == 3, idata.constant_data['temperature'] == 37))
    .rename({'chain': 'sub_chain'})
    .stack({'pooled_chain': ["rows", "sub_chain"]})
    .rename({'pooled_chain': 'chain'})
)

But then I get this:

az.plot_density(filtered, var_names=['mu_obs'])
arviz/stats/ UserWarning: Something failed: `x` does not contain any finite number.
  warnings.warn("Something failed: " + str(e))

… and I see an empty plot. Unfortunately, this is just a warning, not an error, so debugging it will be tricky. Any suggestions? mu_obs has some NaN values in it, but is not all NaNs. Maybe it’s because some of the chains are entirely NaN, and I need to somehow filter the chains? I don’t understand how to squash out coordinates that correspond to empty values in mu_obs.

Haha! Triumph! The final piece of the picture:

az.plot_density(filtered.dropna('chain', subset=['mu_obs']), var_names=['mu_obs'])
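For the record, here is a self-contained sketch of why the dropna is needed, using a toy posterior and a hand-written mask (both are assumptions for illustration). It also shows a possible alternative: passing drop=True to where removes the masked rows before stacking, so no all-NaN pooled chains are created in the first place. The sketch stacks directly into a dimension named chain instead of renaming twice; the result should be the same.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(2)
posterior = xr.Dataset(
    {"mu_obs": (("chain", "draw", "rows"), rng.normal(size=(2, 10, 4)))},
    coords={"chain": [0, 1], "draw": np.arange(10), "rows": np.arange(4)},
)
# Hypothetical mask: only rows 0 and 3 match the conditions of interest.
mask = xr.DataArray(
    [True, False, False, True], dims=["rows"], coords={"rows": np.arange(4)}
)

# Route 1: mask, stack, then drop the pooled chains that are entirely NaN.
stacked = (
    posterior.where(mask)
    .rename({"chain": "sub_chain"})
    .stack(chain=("rows", "sub_chain"))
)
cleaned = stacked.dropna("chain", subset=["mu_obs"])

# Route 2: drop the masked rows first, then stack.
dropped_first = (
    posterior.where(mask, drop=True)
    .rename({"chain": "sub_chain"})
    .stack(chain=("rows", "sub_chain"))
)

print(stacked.sizes["chain"], cleaned.sizes["chain"], dropped_first.sizes["chain"])
```

Either cleaned or dropped_first could then be handed to az.plot_density as in the call above.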

@OriolAbril Would you have a quick peek at my code and see if it corresponds to a normal thing to do, or if what I am trying to do here is odd, since I am not familiar with normal practice?

A question for ArviZ – is it normal to want to do this kind of rearrangement and selection when plotting data? Or is this so unusual, that it does not matter that it is difficult? Or is there something different one should do in the model structure in this situation? Should I, perhaps, reshape the variables (so that instead of having a mu_obs that I select values for, I split into different pseudo-variables for different conditions), so that I don’t need to do this kind of multi-stage filtering.

Now that I know how to do this, I can make it a kind of idiom, but there is no easy path to learn how to do this kind of manipulation at the moment.

Also plot_density may be the easiest, most flexible plotter to do this to. Other plots that cross variables, or combine information in different groups, might be more difficult.

It clearly looks like the flatten argument has to be extended to other plots. Having to rename+stack+rename dimensions to chain in order to get them to get flattened and use all the necessary data in the KDE should not be needed.

The where may be too verbose (xarray is already quite verbose in general too, starting from having a single digit converted to temperature, and having the idata.constant_data does not help), but looks fine to me. I think that & will work as logical and and shorten it a bit, but it is what it is.

If it’s not too much to ask, can you share the output of printing the posterior and constant_data groups? I think that on the xarray side the where is the way to go, and the rename+stack are needed due to limitations in ArviZ, but it will help me get a better picture while I try to see whether flatten works.

I’m afraid not – & gets turned into bitwise-AND, instead of logical and. Google shows me this when I search for “numpy logical operator infix”: Infix operators for numpy arrays « Python recipes « ActiveState Code
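A small check on toy arrays, for what it is worth: when both operands are already boolean (as the results of == comparisons are), elementwise & and np.logical_and should give the same answer. The usual pitfall is operator precedence, since & binds more tightly than ==, so each comparison needs its own parentheses. The arrays below are made up for illustration.

```python
import numpy as np
import xarray as xr

inp = xr.DataArray([3, 3, 1, 3], dims=["rows"], name="input")
temp = xr.DataArray([37, 30, 37, 37], dims=["rows"], name="temperature")

with_function = np.logical_and(inp == 3, temp == 37)
with_operator = (inp == 3) & (temp == 37)  # parentheses are required

# Without parentheses, `inp == 3 & temp == 37` is parsed as the chained
# comparison `inp == (3 & temp) == 37`, which is a different expression.
print(with_function.values)
print(with_operator.values)
```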

Unfortunately, xarray does not (or didn’t when I checked last) have the equivalent of pandas’s query(), which provides nicer syntax for this.

I will confess that my response to the above was “too long; didn’t read.”

This would be hugely helpful, I think.

It would be nice if somehow I could simply treat the temperature and input variables as dimensions, but I don’t know how to do that (especially since they are in the constant_data group, not the posterior group). Intuitively, it seems like the constant_data should be usable as dimensions/coordinates. Maybe supporting that on InferenceData (both for constant_data/posterior or predictions_constant_data/predictions), would be helpful?

I digressed a little while writing this, so here is a TL;DR:

There are 2 main issues contributing to all this unnecessary complexity.

  1. Xarray does not support indexing with coordinate values. Even if we moved them to the posterior dataset on ArviZ side, we would still have to use where. Hopefully this will be fixed on xarray’s side several months from now. Once it is done we can revisit that, as then it may make more sense to move constant_data to posterior and/or posterior_predictive groups.
  2. ArviZ only allows generating the KDEs by flattening the chain and draw dimensions. It loops over all other dimensions as if there were no tomorrow, and this behaviour is currently hardcoded. Luckily, plot_ppc doesn’t have this hardcoded, so solving this is only a matter of exporting the logic in plot_ppc to all other plots. We have to fix this on the ArviZ side.

Yep, agreed that this is the way to go, however for now the only way to do that is to swap/overwrite the actual dimension with the coordinate you want to index with. Hopefully this will change soon!

It is in xarray’s roadmap to allow indexing with any coordinates (see for example Explicit indexes in xarray's data-model (Future of MultiIndex) · Issue #1603 · pydata/xarray · GitHub) and they got CZI funding for this (link to proposal if anyone is interested).

Maybe we can add it as a GSoC project too. The machinery is already there; it’s only a matter of exposing it from the plotting functions in a similar fashion to plot_ppc.

In plot_ppc you can get the KDE by flattening all dimensions (not only the chain and draw ones):


or with only a subset

If this is exported to the other plots, generating KDEs from chain, draw and rows, as you are doing here, would only require passing an extra argument to plot_density.


P.S. When I said I didn’t read the above, I meant the long explanation of how to make an infix operator for numpy – not your messages, @OriolAbril, which have been very helpful!

Yes, here they are: