PyMC+ArviZ: how to make the most of labeled coords and dims in PyMC 4.0

It’s definitely related to the extraction of coords. The determine_coords function in data.py is hard-coded to only look for two dimensions: index and column. pm.Data also explicitly checks that the length of dims is equal to the ndims of the data, so it won’t let you “overload” the dimension by decomposing the multi-index into several coords.

The multi-index is converted to a tuple of tuples in the model.add_coord method (the same conversion happens if it is passed in via the idata_kwargs keyword). If it could survive that step, xarray is happy to take a pd.MultiIndex as a coord, and then all the work is done. To demonstrate this, I pass no coords to the model, then slip in the multi-index after the fact:

prior.prior_predictive.coords.update({'likelihood_dim_0':df.index})

Then you can use .sel as expected on a multi-index, e.g. prior.prior_predictive.sel({'country':'A'}) returns the prior predictive for all sub-regions in country A: A11, A12, etc. This would be quite nice for quickly doing PPCs by different groupings.
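To show the kind of selection I mean without needing a model, here is a minimal sketch in plain xarray. The data, the "draw" and "obs" dimension names, and the sub-region labels are all made up for illustration. I build the multi-index with set_index rather than coords.update, since how xarray handles a pd.MultiIndex assigned directly as a coord may vary between versions.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for a prior predictive variable:
# 3 draws over 4 observations, one observation per (country, region) pair.
values = np.arange(12).reshape(3, 4)
countries = ["A", "A", "B", "B"]
regions = ["A11", "A12", "B11", "B12"]

da = xr.DataArray(
    values,
    dims=("draw", "obs"),
    coords={"country": ("obs", countries), "region": ("obs", regions)},
)

# Combine the two per-observation coords into a single multi-index on "obs".
da = da.set_index(obs=["country", "region"])

# Select every sub-region of country A by the index level name.
country_a = da.sel(country="A")
print(country_a.values)
```

The selection keeps all draws and only the columns belonging to country A, which is the grouping behaviour that would be handy for PPCs.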

As far as implementation goes, I guess either the add_coord method could be modified to let the multi-index through, or a new routine could be added to backends.arviz.InferenceDataConverter to look for the tuple-of-tuples structure, rebuild the multi-index with pd.MultiIndex.from_tuples, and then set the coords. Neither solution seems very clean, but the second would probably risk fewer unintended consequences.
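A rough sketch of what the second option's rebuild step might look like, using only pandas. The helper name rebuild_multiindex and the example index are hypothetical, not anything that exists in PyMC or ArviZ. Note that the level names do not survive the tuple-of-tuples conversion, so they would have to be carried along separately.

```python
import pandas as pd

# Hypothetical multi-index like the one on df.index.
index = pd.MultiIndex.from_tuples(
    [("A", "A11"), ("A", "A12"), ("B", "B11"), ("B", "B12")],
    names=["country", "region"],
)

# Iterating a MultiIndex yields plain tuples, so this mimics the
# flattened tuple-of-tuples form; the level names are lost here.
flattened = tuple(index)


def rebuild_multiindex(values, names=None):
    """Hypothetical helper: rebuild a MultiIndex if every entry is a tuple."""
    values = list(values)
    if values and all(isinstance(v, tuple) for v in values):
        return pd.MultiIndex.from_tuples(values, names=names)
    # Anything else is treated as an ordinary one-level coord.
    return pd.Index(values)


rebuilt = rebuild_multiindex(flattened, names=["country", "region"])
print(rebuilt.equals(index))
```

A real implementation would also need to decide where the level names come from and how to avoid misreading coords that are legitimately tuples.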
