Something that has always bothered me about pymc3 is that we lose any index information we might have had before putting the dataset into pymc3. I often need to do some tinkering to get it back, and I’ve messed that up more than once in the past.
A possible solution to this problem would be to use xarray to store traces. For those who don’t know it, it is basically a multidimensional version of pandas, and works nicely in combination with it. Internally it uses conventions taken from netCDF, which apparently is used a lot in the geosciences (storing climate data etc.). The website has a nice introduction.
The syntax for getting the indices into pymc could be something like this:
import pymc3 as pm
from pandas import Timestamp

# The index dimensions are called coordinates in xarray.
# The arrays can be any pandas index, or something that can be
# converted to one.
coords = {
    'subject': ['Peter', 'Hans'],
    'time': [Timestamp('2017-01-20'), Timestamp('2017-01-21')],
    'treatment': ['sorafenib', 'whatever'],
}

# We pass the coordinates to the model constructor, so that they are
# shared between all variables in the model.
with pm.Model(coords=coords) as model:
    intercept = pm.Flat('intercept')
    sd = pm.HalfStudentT('subject_sd', nu=3, sd=2)
    # Instead of specifying the shape of the variable, we specify
    # the name of the coordinate. It can then infer the shape on its own.
    a = pm.Normal('subject_mu', sd=sd, mu=0, dims='subject')
    effect_sd = pm.HalfStudentT('effect_sd', nu=3, sd=2)
    effect = pm.Normal('effect', sd=effect_sd, dims='treatment')
    interaction_sd = pm.HalfCauchy('interaction_sd', beta=1)
    # We can use more than one dimension.
    interaction = pm.Normal('interaction', mu=0, sd=interaction_sd,
                            dims=('time', 'treatment'))
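To illustrate the comment about inferring the shape, here is a minimal sketch of how a variable's shape could be derived from its dimension names. The helper `shape_from_dims` is hypothetical and not part of pymc3; it only assumes that `coords` maps each dimension name to a sequence of labels.

```python
def shape_from_dims(coords, dims):
    """Return the shape implied by the named dimensions.

    Hypothetical helper for illustration, not a pymc3 function.
    """
    # Allow a single dimension name as shorthand for a 1-tuple,
    # mirroring dims='subject' in the proposed syntax.
    if isinstance(dims, str):
        dims = (dims,)
    missing = [d for d in dims if d not in coords]
    if missing:
        raise KeyError(f"unknown dimensions: {missing}")
    # The length of each coordinate gives the size of that axis.
    return tuple(len(coords[d]) for d in dims)


coords = {
    'subject': ['Peter', 'Hans'],
    'time': ['2017-01-20', '2017-01-21'],
    'treatment': ['sorafenib', 'whatever'],
}

print(shape_from_dims(coords, 'subject'))              # (2,)
print(shape_from_dims(coords, ('time', 'treatment')))  # (2, 2)
```

Raising on unknown dimension names would probably catch typos early, before they show up as a confusing shape mismatch.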
During sampling we would add two more coordinates: sample and chain.
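To make that concrete, here is a small sketch of what a single variable could look like with the two extra coordinates. The numbers are random stand-ins, not real samples, and the choice of 2 chains with 100 draws each is arbitrary.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_chains, n_samples = 2, 100
treatments = ['sorafenib', 'whatever']

# Fake draws for 'effect' with dimensions (chain, sample, treatment).
effect = xr.DataArray(
    rng.normal(size=(n_chains, n_samples, len(treatments))),
    coords={
        'chain': np.arange(n_chains),
        'sample': np.arange(n_samples),
        'treatment': treatments,
    },
    dims=('chain', 'sample', 'treatment'),
    name='effect',
)

# Select by label instead of by position.
sorafenib = effect.sel(treatment='sorafenib')
print(sorafenib.dims)  # ('chain', 'sample')

# Reduce over named dimensions; the treatment labels are kept.
post_mean = effect.mean(dim=['chain', 'sample'])
print(post_mean.dims)  # ('treatment',)
```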
To get a feeling for how such a trace could be used, you can use this function to convert a normal trace to an xarray Dataset:
import numpy as np
import xarray as xr


def to_xarray(trace, coords, dims):
    """Convert a pymc3 trace to an xarray dataset.

    Parameters
    ----------
    trace : pymc3 trace
    coords : dict
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, Tuple(str)]
        A mapping from pymc3 variables to a tuple corresponding to
        the shape of the variable, where the elements of the tuples are
        the names of the coordinate dimensions.

    Example
    -------
    ::

        coords = {
            'subject': ['Peter', 'Hans'],
            'time': [Timestamp('2017-01-20'), Timestamp('2017-01-21')],
            'treatment': ['sorafenib', 'whatever']
        }
        dims = {
            'subject_mu': ('subject',),
            'effect': ('treatment',),
            'interaction': ('time', 'treatment'),
        }
    """
    # Add the two sampling coordinates without modifying the caller's dict.
    coords = coords.copy()
    coords['sample'] = list(range(len(trace)))
    coords['chain'] = list(range(trace.nchains))

    # Wrap every coordinate in an IndexVariable named after its dimension.
    coords_ = {}
    for key, vals in coords.items():
        coords_[key] = xr.IndexVariable((key,), data=vals)
    coords = coords_

    data = xr.Dataset(coords=coords)
    for key in trace.varnames:
        # Skip names ending in '_' (the suffix pymc3 gives to
        # automatically transformed variables).
        if key.endswith('_'):
            continue
        # Every variable has the chain and sample dimensions first.
        dims_str = ('chain', 'sample')
        if key in dims:
            dims_str = dims_str + dims[key]
        # One array per chain; stacking them gives (chain, sample, ...).
        vals = trace.get_values(key, combine=False, squeeze=False)
        vals = np.array(vals)
        data[key] = xr.DataArray(vals, {v: coords[v] for v in dims_str}, dims=dims_str)
    return data
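Once the trace is an xarray Dataset, getting summaries back into pandas with the original index should be straightforward. The sketch below does not run a sampler: it builds a small Dataset by hand with the layout described above (random stand-in values, plain strings instead of Timestamps for the time labels) and converts a posterior mean to a pandas Series.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
times = ['2017-01-20', '2017-01-21']  # stand-ins for the Timestamp labels
treatments = ['sorafenib', 'whatever']

# Hand-built Dataset with dimensions (chain, sample, time, treatment).
ds = xr.Dataset(
    {
        'interaction': (
            ('chain', 'sample', 'time', 'treatment'),
            rng.normal(size=(2, 50, 2, 2)),
        ),
    },
    coords={
        'chain': [0, 1],
        'sample': np.arange(50),
        'time': times,
        'treatment': treatments,
    },
)

# Average over the sampling dimensions, then hand the result to pandas.
mean = ds['interaction'].mean(dim=['chain', 'sample'])
series = mean.to_series()

# The Series carries a MultiIndex built from the remaining coordinates.
print(list(series.index.names))  # ['time', 'treatment']
print(series.loc[('2017-01-20', 'sorafenib')])
```

The labels survive the reduction, so there is no need to re-attach an index by hand afterwards.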