How can I jitter the values in an InferenceData object so I can plot without degeneracy issues? (aka map a function that outputs random values for each element)

A little related to (Is it possible to transform Trace Data?) and, more generally, to arviz.InferenceData.map — ArviZ 0.15.1 documentation

It takes a minute to explain this, but I think the question might boil down to something like “how can I force a function to evaluate each time it’s called when mapped” - which is more general Python than anything specific to pymc / arviz. Nonetheless, I’ve tried a few things to no avail, so I’m asking in case anyone has an idea from left field.

  1. I have an InferenceData object that contains a set of prior predictive samples (idata = pm.sample_prior_predictive()).
  2. My model contains chol, corr_, stds_ = pm.LKJCholeskyCov('lkjcc', n=2, eta=2.0, sd_dist=sd_dist, compute_corr=True)
  3. My idata.prior.lkjcc_corr contains degenerate values (all 1’s) on the diagonals
  4. I want to plot lkjcc_corr, e.g. _ = az.plot_posterior(idata, group='prior', var_names='lkjcc_corr') … but this throws matplotlib plotting errors because it’s trying to estimate a KDE for the degenerate 1’s on the diagonals.
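For reference, a minimal sketch of this setup (the model body and sd_dist here are assumptions for illustration, not my actual model):

import arviz as az
import pymc as pm

# minimal reproduction sketch; sd_dist is an assumed example distribution
with pm.Model():
    sd_dist = pm.Exponential.dist(1.0)
    chol, corr_, stds_ = pm.LKJCholeskyCov(
        'lkjcc', n=2, eta=2.0, sd_dist=sd_dist, compute_corr=True
    )
    idata = pm.sample_prior_predictive()

# idata.prior['lkjcc_corr'] now has identically-1 diagonals, so this errors:
# _ = az.plot_posterior(idata, group='prior', var_names='lkjcc_corr')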

I’d like to apply a jitter, something like the following, but the random draw is only evaluated once per dataset, so every value gets the same modification:

rng = np.random.default_rng(seed=42)
choices = np.linspace(-1e-10, 1e-10, 100)
idata = mdl.idata.map(lambda x: x + rng.choice(choices), groups='prior')

I think the .map approach is the closest you’ll get to this. To have the exact same transformation applied every time, you’d need to include the seed inside the function, something like:

choices = np.linspace(-1e-10, 1e-10, 100)
idata = mdl.idata.map(lambda x: x + np.random.default_rng(seed=42).choice(choices), groups='prior')
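If you also want a different offset for each element (rather than one scalar broadcast over the whole array), something along these lines might work (a sketch, assuming mapping over each variable with Dataset.map is acceptable):

choices = np.linspace(-1e-10, 1e-10, 100)
# draw an array of offsets matching each variable's shape, so every element
# gets its own jitter value instead of a single shared scalar
idata = mdl.idata.map(
    lambda ds: ds.map(
        lambda da: da + np.random.default_rng(seed=42).choice(choices, size=da.shape)
    ),
    groups='prior',
)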

Regarding lazy operations: ArviZ does nothing on this front and delegates everything to xarray, so you’d have to convert the datasets inside the InferenceData to dask (you can use the .chunk method for this). Computations then won’t actually be evaluated until you call .compute or access the data as a numpy array, with .values for example.
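A rough sketch of what that could look like (assuming dask is installed, and using a trivial constant offset just to show the laziness):

prior = idata.prior.chunk()     # dask-backed dataset, operations are now lazy
jittered = prior + 1e-10        # builds a task graph, nothing computed yet
evaluated = jittered.compute()  # evaluation happens here (or via .values)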

Another option you might want to consider is excluding the diagonal from plotting. The Label guide — ArviZ 0.15.1 documentation covers this using precisely a covariance matrix as the example, a 3x3 one in that case; not sure how big yours is.


Thanks @OriolAbril, I’ll give those a try tomorrow, especially your last note about excluding the diagonal from the plot - which seems like a more efficient option all round

Hmmm… If I might bother you again @OriolAbril, I’m a little at sea.

Using the coords kwarg seems like a good direction. However, the object I’m referencing is the _corr output of an LKJCholeskyCov, which is created as a Deterministic without any coords naming:

corr = pm.Deterministic(f"{name}_corr", corr)

here: pymc/pymc/distributions/multivariate.py at 7b08fc160bc07adcfcd10e135e727c681ffb1b77 · pymc-devs/pymc · GitHub

I can isolate a single element from the correlation matrix, e.g. the off-diagonal [0, 1] “cell”:

coords = {"lkjcc_corr_dim_0": [0], "lkjcc_corr_dim_1": [1]}
_ = az.plot_posterior(mdl.idata, group='prior', var_names=['lkjcc_corr'], coords=coords)

… but I’d like to also get the other off-diagonal at [1, 0]

This: coords = {"lkjcc_corr_dim_0": [0, 1], "lkjcc_corr_dim_1": [1, 0]} doesn’t work, so I think I have the wrong conceptual understanding of how coords work…

Hopefully a simple thing?

Aha, the penny drops!

Per Label guide — ArviZ 0.15.1 documentation

" To select a non rectangular slice with xarray and to get the result flattened and without NaNs, we can use DataArray s indexed with a dimension that is not present in our current dataset:"

This works:

import xarray as xr
coords = {
    'lkjcc_corr_dim_0': xr.DataArray([0, 1], dims=['asdf']),
    'lkjcc_corr_dim_1': xr.DataArray([1, 0], dims=['asdf'])
}
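…and then passing those coords to the same call as before:

_ = az.plot_posterior(mdl.idata, group='prior', var_names=['lkjcc_corr'], coords=coords)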

Thanks!


Yes, you need “vectorized” indexing for this to work properly, whereas xarray’s default is “outer” indexing. To trigger the “vectorized” one you need to use DataArrays that share the same dimension (one that is also different from the existing dimensions). This is a bit different from numpy, but in this case I think it is a good thing: numpy fancy indexing is quite “unpredictable”, and it is not clear which of the “vectorized” or “outer” indexing modes you’ll get until you try. If this still feels a bit magic, the numpy improvement proposal on this is a very interesting doc which I think explains all of it quite well: NEP 21 — Simplified and explicit advanced indexing — NumPy Enhancement Proposals
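A small standalone example to make the two modes concrete:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(9).reshape(3, 3), dims=['row', 'col'])

# "outer" indexing (plain lists): the cross product of the indexers,
# a 2x2 block that still includes the diagonal cells
block = da.isel(row=[0, 1], col=[1, 0])

# "vectorized" indexing (DataArrays sharing a new dimension): paired indices,
# returning just the two off-diagonal elements as a flat 1-D array
points = da.isel(
    row=xr.DataArray([0, 1], dims='pts'),
    col=xr.DataArray([1, 0], dims='pts'),
)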
