Use xarray for traces

Something that’s always bothered me about pymc3 is that we lose any index information we had before putting the dataset into pymc3. I often need to do some tinkering to get it back, and I’ve messed that up more than once in the past.
A possible solution to this problem would be to use xarray to store traces. For those who don’t know it, it is basically a multidimensional version of pandas, and works nicely in combination with it. Internally it uses conventions taken from netCDF, which is apparently used a lot in the geosciences (storing climate data, etc.). The website has a nice introduction.
The syntax for getting the indices into pymc3 could be something like this:

import pymc3 as pm
from pandas import Timestamp

# The index dimensions are called coordinates in xarray.
# The arrays can be any pandas index, or something that can be
# converted to one.
coords = {
    'subject': ['Peter', 'Hans'],
    'time': [Timestamp('2017-01-20'), Timestamp('2017-01-21')],
    'treatment': ['sorafenib', 'whatever'],
}

# We pass the coordinates to the model constructor, so that they are
# shared between all variables in the model.
with pm.Model(coords=coords) as model:
    intercept = pm.Flat('intercept')
    sd = pm.HalfStudentT('subject_sd', nu=3, sd=2)
    # Instead of specifying the shape of the variable, we specify
    # the name of the coordinate. It can then infer the shape on its own.
    a = pm.Normal('subject_mu', sd=sd, mu=0, dims='subject')

    effect_sd = pm.HalfStudentT('effect_sd', nu=3, sd=2)
    effect = pm.Normal('effect', sd=effect_sd, dims='treatment')
    interaction_sd = pm.HalfCauchy('interaction_sd', beta=1)
    # We can use more than one dimension.
    interaction = pm.Normal('interaction', mu=0, sd=interaction_sd,
                            dims=('time', 'treatment'))
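The comment in the model above says the shape can be inferred from the name of the coordinate. A minimal sketch of that inference, using a hypothetical helper infer_shape (not part of pymc3), is simply the length of each named coordinate:

```python
from typing import Dict, Sequence, Tuple, Union

import pandas as pd


def infer_shape(coords: Dict[str, Sequence],
                dims: Union[str, Tuple[str, ...]]) -> Tuple[int, ...]:
    """Hypothetical helper: derive a variable's shape from named dims."""
    # Allow a single dimension name as a plain string, as in dims='subject'.
    if isinstance(dims, str):
        dims = (dims,)
    missing = [d for d in dims if d not in coords]
    if missing:
        raise KeyError(f"No coordinate values given for dims: {missing}")
    # Each axis is as long as the number of index values for its dimension.
    return tuple(len(coords[d]) for d in dims)


coords = {
    'subject': ['Peter', 'Hans'],
    'time': [pd.Timestamp('2017-01-20'), pd.Timestamp('2017-01-21')],
    'treatment': ['sorafenib', 'whatever'],
}
print(infer_shape(coords, 'subject'))                # (2,)
print(infer_shape(coords, ('time', 'treatment')))    # (2, 2)
```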

During sampling we would add two more coordinates: sample and chain.
To get a feeling for how such a trace could be used, you can use this
function to convert a regular pymc3 trace to an xarray Dataset:

import numpy as np
import xarray as xr


def to_xarray(trace, coords, dims):
    """Convert a pymc3 trace to an xarray dataset.

    Parameters
    ----------
    trace : pymc3 MultiTrace
    coords : dict
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, tuple[str]]
        A mapping from pymc3 variable names to a tuple corresponding to
        the shape of the variable, where the elements of the tuples are
        the names of the coordinate dimensions. Variables that are not
        listed must be scalars.

    Example
    -------
    ::

        coords = {
            'subject': ['Peter', 'Hans'],
            'time': [Timestamp('2017-01-20'), Timestamp('2017-01-21')],
            'treatment': ['sorafenib', 'whatever'],
        }
        dims = {
            'subject_mu': ('subject',),
            'effect': ('treatment',),
            'interaction': ('time', 'treatment'),
        }
    """
    coords = coords.copy()
    coords['sample'] = list(range(len(trace)))
    coords['chain'] = list(range(trace.nchains))
    coords_ = {}
    for key, vals in coords.items():
        coords_[key] = xr.IndexVariable((key,), data=vals)
    coords = coords_
    data = xr.Dataset(coords=coords)
    for key in trace.varnames:
        # Skip automatically transformed variables such as 'subject_sd_log__'.
        if key.endswith('_'):
            continue
        dims_str = ('chain', 'sample')
        if key in dims:
            dims_str = dims_str + dims[key]
        vals = trace.get_values(key, combine=False, squeeze=False)
        vals = np.array(vals)
        data[key] = xr.DataArray(vals, {v: coords[v] for v in dims_str}, dims=dims_str)
    return data
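To give an idea of what the resulting Dataset offers, the sketch below builds a Dataset with the same (chain, sample, ...) layout that to_xarray produces, filled with random numbers in place of real posterior draws, and then selects and summarizes by label instead of by integer position:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_chains, n_samples = 2, 100
coords = {
    'chain': np.arange(n_chains),
    'sample': np.arange(n_samples),
    'subject': ['Peter', 'Hans'],
    'treatment': ['sorafenib', 'whatever'],
}
# Random numbers standing in for posterior draws (not real results).
data = xr.Dataset(
    {
        'subject_mu': (('chain', 'sample', 'subject'),
                       rng.normal(size=(n_chains, n_samples, 2))),
        'effect': (('chain', 'sample', 'treatment'),
                   rng.normal(size=(n_chains, n_samples, 2))),
    },
    coords=coords,
)

# All draws for one subject, selected by name.
peter = data['subject_mu'].sel(subject='Peter')
# Posterior means pooled over chains and samples; the subject index survives.
means = data['subject_mu'].mean(dim=('chain', 'sample'))
print(means.to_series())
# Difference between treatments, still labelled by chain and sample.
contrast = (data['effect'].sel(treatment='sorafenib')
            - data['effect'].sel(treatment='whatever'))
```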

I think this would be quite convenient for many of my applications (behavioural data analysis with repeated measures, hierarchical models), but it might cost some flexibility: is the shape argument still available in this case?

I don’t see a problem with mixing shape and dims. We’d just have to create a default integer-range coordinate for each dimension that doesn’t have a specified coordinate. pm.sample should probably also get an additional kwarg to tell it that we want an xarray trace.
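One way mixing shape and dims could work, sketched with a hypothetical helper resolve_dims (the '{name}_dim_{i}' naming scheme is an assumption for illustration, not an existing pymc3 convention): axes without a named coordinate get a generated dimension name and a default integer range as index.

```python
import numpy as np


def resolve_dims(name, shape, dims=None, coords=None):
    """Hypothetical helper: combine an explicit shape with partial dims.

    ``dims`` has one entry per axis; ``None`` marks an axis without a named
    coordinate. Such axes get the generated name '{name}_dim_{i}' and the
    default coordinate values 0..size-1.
    """
    if dims is None:
        dims = (None,) * len(shape)
    if len(dims) != len(shape):
        raise ValueError("dims must have one entry per axis of shape")
    coords = dict(coords or {})  # don't mutate the caller's dict
    resolved = []
    for i, (dim, size) in enumerate(zip(dims, shape)):
        if dim is None:
            dim = f"{name}_dim_{i}"
        if dim not in coords:
            # No index values given: fall back to a plain integer range.
            coords[dim] = np.arange(size)
        elif len(coords[dim]) != size:
            raise ValueError(
                f"coordinate {dim!r} has {len(coords[dim])} values, "
                f"but shape requires {size}"
            )
        resolved.append(dim)
    return tuple(resolved), coords


# 'time' comes from the named coords; the second axis only has a shape.
dims, new_coords = resolve_dims(
    'interaction', (2, 3), dims=('time', None),
    coords={'time': ['2017-01-20', '2017-01-21']},
)
print(dims)  # ('time', 'interaction_dim_1')
```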

I created a simple notebook with data generation / model / trace all in xarray. If anyone wants to play around with it:

All in all, using xarray feels great, and an xr.Dataset is a great way to keep everything together. The only thing I don’t like so far is that a database-join-like operation seems to be missing. It is possible to work around that by doing this:

data['treatment'].isel_points(data.subject, treatment=data.treated_idx)

It feels a bit strange, though. Maybe I’m just missing something; I’m still new to xarray.
I’ll try to put a bit of text in there and make a blog post about it, but this might still take a bit of time.
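For the join-like lookup via isel_points above, newer xarray versions deprecated isel_points in favour of vectorized indexing: if the indexer is itself a DataArray, the result takes on the indexer's dimensions. A small sketch with made-up toy data (treatment_effect is a hypothetical variable name; treated_idx mirrors the one used above):

```python
import numpy as np
import xarray as xr

# Toy data: one value per treatment, and which treatment each subject got.
data = xr.Dataset(
    {
        'treatment_effect': ('treatment', np.array([0.5, -0.2])),
        'treated_idx': ('subject', np.array([1, 0, 1])),
    },
    coords={
        'treatment': ['sorafenib', 'whatever'],
        'subject': ['Peter', 'Hans', 'Anna'],
    },
)

# The indexer is a DataArray along 'subject', so the result is indexed by
# 'subject' instead of 'treatment', similar to a join on treated_idx.
per_subject = data['treatment_effect'].isel(treatment=data['treated_idx'])
print(per_subject)
```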

This looks really cool, and I intend to play around with it some. Do you have some intuition about how this affects speed and memory usage? My other worry (which I might also resolve by just doing some reading myself) is whether this means we’ll miss out on upstream numpy improvements.

Is your suggestion to only support xarray, or to support xarray + NDArray, or something else?

@colcarroll Don’t worry, I have absolutely no plans to do something backward incompatible. I don’t think we should support only xarray, and I think it shouldn’t be the default either (at least for the foreseeable future). Using xarray isn’t that difficult, but it would be something else all users would have to learn.

About speed/memory: I don’t think it will make much of a difference either way. Internally xarray uses numpy anyway; it just needs a little additional memory to store the coordinates. Sampling itself wouldn’t change at all, except perhaps for the cost of storing the trace, which typically shouldn’t be very expensive anyway.
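A quick way to check the memory argument, as a sketch with random data standing in for draws: the DataArray holds a plain numpy array, and the only extra storage is the coordinate arrays, which grow with the sum of the dimension lengths rather than with their product.

```python
import numpy as np
import xarray as xr

# Random stand-in for draws with layout (chain, sample, subject).
draws = np.random.default_rng(0).normal(size=(4, 1000, 2))
da = xr.DataArray(
    draws,
    coords={'chain': np.arange(4), 'sample': np.arange(1000),
            'subject': ['Peter', 'Hans']},
    dims=('chain', 'sample', 'subject'),
)

# The values are still a plain numpy array.
print(type(da.values))
# Expected to be True: the DataArray wraps the existing array, no copy.
print(np.shares_memory(da.values, draws))
# Extra memory is only the coordinate arrays.
coord_bytes = sum(da[c].values.nbytes for c in da.coords)
print(draws.nbytes, coord_bytes)
```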

Where we might get a good speedup is in storing and loading traces: xarray supports storing datasets as netCDF (the netCDF4 format is built on HDF5) and can load variables lazily.
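A minimal sketch of a netCDF round trip, assuming one of xarray's netCDF backends (for example the netCDF4, h5netcdf or scipy package) is installed; the data are random numbers, not a real trace:

```python
import os
import tempfile

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
trace = xr.Dataset(
    {'subject_mu': (('chain', 'sample', 'subject'),
                    rng.normal(size=(2, 500, 3)))},
    coords={'chain': np.arange(2), 'sample': np.arange(500),
            'subject': np.arange(3)},
)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'trace.nc')
    trace.to_netcdf(path)
    # open_dataset reads the metadata; variable data are loaded lazily,
    # i.e. only when values are actually requested.
    with xr.open_dataset(path) as loaded:
        first_chain = loaded['subject_mu'].isel(chain=0).values
        roundtrip_ok = bool(np.allclose(loaded['subject_mu'].values,
                                        trace['subject_mu'].values))
```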