Working with Multi-index Coords

Hello,

I’ve got an ArviZ and PyMC combo question. I’ve been working with some uneven, long data that I’m managing with multi-indexes. This workflow works pretty well, except PyMC models can’t accept multi-indexes in the coords dictionary. To work around this, I just create a 1D array like this: coords = {'tests': np.arange(num_tests)}. After the model inference is done, I can assign the 1D array to a multi-index like this: idata = idata.assign_coords(multi_index_tests).unstack(). This is where things get difficult. Say my multi-index is named “tests”; after unstacking, I will have xarray dims named tests_level_0, tests_level_1, etc.

In theory, I could then use rename to give these dims the names they’re supposed to represent. However, I’ve found that this quickly becomes difficult when there are name clashes. For example, in my model I have two different multi-indexes that share some, but not all, levels. This means I end up trying to rename both designs_level_1 and tests_level_1 to time, which xarray doesn’t allow. Another complication arises when one level of the multi-index is already a coord in the PyMC model, so renaming the unstacked multi-index produces an obvious name clash.

I ultimately get around these with an awkward combination of reset_index, assign_coords, and rename (see below), but this wasn’t well documented in either ArviZ or xarray. Is there a better way to do this? Is there a way I could pass the multi-index to PyMC and have it handle this?

Thank you all for your time.

For clarity, this is ultimately what I ended up doing:

times = [0, 15, 30, 45, 60, 75, 90, 120]
replications = [0, 1, 2]


groups = [
    "posterior_predictive",
    "posterior",
    "constant_data",
    "prior",
    "prior_predictive",
    "observed_data",
]

ds = idata.assign_coords(
    mindex_coords_tests,
    groups=groups,
).unstack()

ds = ds.assign_coords(mindex_coords_designs, groups=groups).unstack()

ds = (
    ds.reset_index(
        ["designs_level_0", "tests_level_0", "designs_level_1", "tests_level_1"],
        groups=groups,
        drop=True,
    )
    .assign_coords({"times": times})
    .rename(
        {
            "designs_level_0": "voids",
            "tests_level_0": "voids",
            "tests_level_1": "times",
            "designs_level_1": "times",
            "tests_level_2": "replications",
        },
        groups=groups,
    )
)
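For anyone who wants to reproduce the level-name behavior described above, here is a minimal sketch with made-up toy data (the index values and the "tests" name are just for illustration): an unnamed-level pandas MultiIndex assigned to a dim gets default level names derived from the dim name, which is where tests_level_0 etc. come from.

```python
import numpy as np
import pandas as pd
import xarray as xr

# A MultiIndex whose levels are deliberately left unnamed,
# standing in for the "tests" dimension in the post above
index = pd.MultiIndex.from_product([["a", "b"], [0, 1]])
mindex_coords = xr.Coordinates.from_pandas_multiindex(index, "tests")

da = xr.DataArray(np.arange(4.0), dims=["tests"]).assign_coords(mindex_coords)

# Unnamed levels get default names of the form "<dim>_level_<i>"
unstacked = da.unstack()
print(sorted(unstacked.dims))  # ['tests_level_0', 'tests_level_1']
```

Naming the MultiIndex levels up front (e.g. names=['void', 'time']) avoids the rename step entirely, as long as the names don’t clash with existing coords.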

Yeah I share your pain @dreycenfoiles , I also work with MultiIndexes all the time :sweat_smile:

I have a workflow similar to yours that I tend to use, but very recently I ended up throwing it away and doing something much simpler. Just use something like:

"obs_id": [
    f"{date}_{loc}_{state}" for date, loc, state in Y[["date", "loc", "state"]].to_numpy()
]

Then, you can index your InferenceData object with isel, using your Y dataframe with the specific combinations of date and location (for instance) that you want:
idata.isel(obs_id=Y[select_with_pandas].index)
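To make the trick concrete, here is a self-contained sketch (the DataFrame Y, its columns, and the mock draws are all made up for illustration). Note it relies on Y keeping its default RangeIndex, so that .index doubles as positional indices for isel:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical observation table, standing in for Y
Y = pd.DataFrame({
    "date": ["2024-01", "2024-01", "2024-02", "2024-02"],
    "loc": ["north", "south", "north", "south"],
    "state": ["A", "B", "A", "B"],
})

# One flat string label per row, instead of a MultiIndex
obs_id = [f"{date}_{loc}_{state}" for date, loc, state in Y[["date", "loc", "state"]].to_numpy()]

# Mock draws keyed by obs_id (stands in for a variable in idata)
da = xr.DataArray(
    np.zeros((2, len(Y))), dims=["draw", "obs_id"], coords={"obs_id": obs_id}
)

# Pandas does the selection; isel just takes the positional index.
# This assumes Y has a default RangeIndex (labels == positions).
subset = da.isel(obs_id=Y[Y["loc"] == "north"].index)
print(subset.coords["obs_id"].values)  # ['2024-01_north_A' '2024-02_north_A']
```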

So yeah, I’m basically outsourcing the selection to Pandas :sweat_smile:
But the fact that xarray can’t handle selection on a MultiIndex (at least as far as I know) is a huge limitation for me in these cases.

Another advantage of this solution is that it avoids unstack or reset_index("obs_id") which are both computationally intensive – if you have big data, your RAM will break down :sweat_smile:

Hope this helps :vulcan_salute:


xarray handles selection on MultiIndex just fine, it’s just that arviz won’t put them there for you by default. Generate a simple model with a panel structure:

import pymc as pm
import numpy as np
import pandas as pd
import xarray as xr

# In reality all this code would be replaced with pd.factorize...
times = np.array([0, 15, 30, 45, 60, 75, 90, 120])
replications = np.array([0, 1, 2])
n_obs = 100
obs_idx = np.arange(n_obs)

time_idx = np.random.choice(len(times), size=n_obs, replace=True)
replication_idx = np.random.choice(len(replications), size=n_obs, replace=True)

# Create a multi-index object. Again, from real data you'd just have this.
# Important: make sure the index levels are named. You can set these with 
# df.index.names = ['time', 'replication'] in a "real" case
index = pd.MultiIndex.from_arrays([times[time_idx], replications[replication_idx]], names=['time', 'replication'])


coords = {
    'time':times,
    'replication':replications,
    'obs_idx':obs_idx
}

with pm.Model(coords=coords) as m:
    time_effect = pm.Normal('time_effect', dims=['time'])
    replication_effect = pm.Normal('replication_effect', dims=['replication'])
    
    mu = pm.Deterministic('mu',
                          time_effect[time_idx] + replication_effect[replication_idx],
                          dims=['obs_idx'])
    
    sigma = pm.Exponential('sigma', 1)
    
    y_hat = pm.Normal('y_hat', mu=mu, sigma=sigma, dims=['obs_idx'])
    prior = pm.sample_prior_predictive()

Now inject a MultiIndex onto the prior:

obs_dim = xr.Coordinates.from_pandas_multiindex(index, 'obs_idx')
prior = prior.assign_coords(obs_dim)

You can now do e.g. .sel to ask for specific times from the y_hat variable:

print(prior.prior.y_hat.sel(time=60))

<xarray.DataArray 'y_hat' (chain: 1, draw: 500, replication: 8)> Size: 32kB
array([[[ 0.46856649, -0.89734453,  0.50073046, ..., -0.47615646,
          0.6271321 ,  0.96559254],
        [ 0.07794629, -0.4126819 ,  0.0792669 , ..., -0.30008756,
         -0.071868  ,  0.0563495 ],
        [-1.60604729,  0.52474826, -3.03402283, ...,  0.52440493,
         -0.91028064, -0.01785314],
        ...,
        [ 1.4605323 , -0.42731286, -1.06585537, ..., -0.30773731,
         -0.55402316, -2.04667398],
        [-1.19795181, -1.14820052, -1.34431257, ..., -0.79749659,
         -1.06704442, -0.70265358],
        [ 0.62468763,  0.76633682,  0.64597336, ...,  0.76229302,
         -0.56777422, -0.65927168]]])
Coordinates:
  * chain        (chain) int64 8B 0
  * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
  * replication  (replication) int64 64B 2 0 2 0 2 0 1 1
    time         int64 8B 60

The downside of this approach is that it overrides the existing time dimension, so for example if you try prior.prior.sel(time=60), you will select time on mu and y_hat, but not on time_effect, which no longer has labels. You can avoid this by not reusing existing dimension names on the pd.MultiIndex, for example by using time_idx and replication_idx instead. E.g.:

index = pd.MultiIndex.from_arrays([times[time_idx], replications[replication_idx]], names=['time_idx', 'replication_idx'])

# ... same same same ...

obs_dim = xr.Coordinates.from_pandas_multiindex(index, 'obs_idx')
prior = prior.assign_coords(obs_dim)

Result:

print(prior.prior.sel(time=60, time_idx=60))

<xarray.Dataset> Size: 136kB
Dimensions:             (chain: 1, draw: 500, replication_idx: 14,
                         replication: 3)
Coordinates:
  * chain               (chain) int64 8B 0
  * draw                (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
  * replication         (replication) int64 24B 0 1 2
    time                int64 8B 60
  * replication_idx     (replication_idx) int64 112B 2 1 0 0 0 2 1 1 1 0 0 2 1 1
    time_idx            int64 8B 60
Data variables:
    mu                  (chain, draw, replication_idx) float64 56kB -0.6313 ....
    replication_effect  (chain, draw, replication) float64 12kB 0.04898 ... 0...
    sigma               (chain, draw) float64 4kB 2.797 0.2466 ... 0.2442 0.6758
    time_effect         (chain, draw) float64 4kB -0.1613 0.1042 ... -1.252
    y_hat               (chain, draw, replication_idx) float64 56kB -5.824 .....
Attributes:
    created_at:                 2024-11-15T05:06:32.927238+00:00
    arviz_version:              0.18.0
    inference_library:          pymc
    inference_library_version:  5.16.2

I agree with all of the workflow @jessegrabowski wrote above. The reason I don’t use this anymore (apart from the computational overhead I outlined above) is that you can’t do aggregation operations on a MultiIndex in xarray:

prior.prior.y_hat.mean("time")

ValueError: 'time' not found in array dimensions ('chain', 'draw', 'obs_idx')

To me, that’s a big drawback, because most of the time that’s precisely why I’m interested in having the MultiIndex in there: selecting some slice, then aggregating over one or several dimensions.

To be clear, it can be done – unstack or reset_index are good solutions – but I haven’t found yet a pure xarray solution that doesn’t call for some exception handling at some point. On the contrary, using the Pandas index as I do now works out-of-the-box :man_shrugging:
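For reference, the unstack route looks like this on toy data (the names and values are made up): expand the MultiIndex levels into real dimensions first, then aggregate.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy draws on an obs_id dimension carrying a named MultiIndex
index = pd.MultiIndex.from_product(
    [[0, 15, 30], [0, 1]], names=["time", "replication"]
)
da = xr.DataArray(
    np.arange(6.0),
    dims=["obs_id"],
    coords=xr.Coordinates.from_pandas_multiindex(index, "obs_id"),
)

# unstack turns the levels into real dims, so reductions work again
mean_over_time = da.unstack("obs_id").mean("time")
print(mean_over_time.values)  # [2. 3.]
```

The cost is that unstack materializes the full dense (time, replication) array, which is exactly the memory blow-up mentioned above for large, uneven data.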

What’s the expected output of prior.prior.y_hat.mean("time")? This seems like it should be a groupby: prior.prior.y_hat.groupby('time_idx').mean()
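A quick check with toy data (names and values assumed) showing that grouping by a single MultiIndex level does work without any unstacking:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy y_hat on an obs_id dimension with a named MultiIndex
index = pd.MultiIndex.from_arrays(
    [[0, 0, 15, 15], [0, 1, 0, 1]], names=["time_idx", "replication_idx"]
)
y_hat = xr.DataArray(
    np.arange(4.0),
    dims=["obs_id"],
    coords=xr.Coordinates.from_pandas_multiindex(index, "obs_id"),
)

# Grouping by one level of the MultiIndex works fine
grouped = y_hat.groupby("time_idx").mean()
print(grouped.values)  # [0.5 2.5]
```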

That’s exactly it, and actually what I use it for all the time. The problem is that, to the best of my knowledge, you can’t group by multiple dimensions, i.e. y_hat.groupby(['time_idx', 'replication_idx']).mean().

In this example it doesn’t make sense because the MultiIndex is 2D, but you get the idea

Yeah, for something like this I’d use to_dataframe() (or to_dask_dataframe() for big data) and do the groupby there, something like:

prior.prior.y_hat.to_dataframe().y_hat.groupby(['time_idx', 'unit_idx']).mean().unstack()
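A runnable sketch of that round-trip with toy data (the names and values are made up): to_dataframe expands the MultiIndex levels into the frame’s index, so pandas resolves the level names and the multi-level groupby goes through.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy y_hat draws on an obs_id dimension with a named MultiIndex
index = pd.MultiIndex.from_arrays(
    [[0, 0, 15, 15], [0, 1, 0, 1]], names=["time_idx", "replication_idx"]
)
y_hat = xr.DataArray(
    np.arange(8.0).reshape(2, 4),
    dims=["draw", "obs_id"],
    coords=xr.Coordinates.from_pandas_multiindex(index, "obs_id"),
)

# Round-trip through pandas: group by both level names (averaging
# over draws), then pivot replication_idx back out as columns
table = (
    y_hat.to_dataframe("y_hat")["y_hat"]
    .groupby(["time_idx", "replication_idx"])
    .mean()
    .unstack()
)
print(table.loc[0, 0], table.loc[15, 1])  # 2.0 5.0
```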

Yep. TBH in my experience that’s a lot of overhead code for a marginal benefit, when you can do it in just one line of code with the index, as I did above. But that’s just my personal preference


I see. So if I’m understanding you correctly, it’s more of an xarray limitation rather than anything PyMC could do to make it easier for users?

Either way, I like this trick. I was noticing that unstack is pretty sluggish, so avoiding it is awesome! Thank you for your help!

Mainly, yes. Happy to hear that was helpful!
Live Bayes & Prosper :vulcan_salute: