If I want to sample a multivariate normal, and I’d like the results of the idata object to have appropriately named dimensions and coordinates, how do I achieve that?
This code runs, but I’m not sure what to do with the coords to get them attached to the various parameters (mu, chol, Rho_, sigma_cafe):
import numpy as np
import pymc as pm
k = 3
n = 100
mean = np.random.randn(k)
covariance = np.random.rand(k, k)
covariance = np.dot(covariance, covariance.T)
data = np.random.multivariate_normal(mean, covariance, n)
coords = {"vars": np.arange(k)}
with pm.Model(coords=coords) as model:
    mu = pm.MvNormal('mu', mu=np.zeros([k,]), cov=np.eye(k))
    chol, Rho_, sigma_cafe = pm.LKJCholeskyCov('L_chol', n=k, eta=2, sd_dist=pm.Exponential.dist(1.0), compute_corr=True)
    # chol is a Cholesky factor, so it is passed as chol= rather than cov=
    y_obs = pm.MvNormal('y_obs', mu=mu, chol=chol, observed=data)
    trace = pm.sample(1000, tune=1000, random_seed=1234)
pm.summary(trace, var_names=['mu', 'L_chol'])
There was some draft work on this by @jessegrabowski:
pymc-devs:main ← jessegrabowski:lkj_coords (opened 07:35AM - 14 Jul 23 UTC)
**What is this PR about?**
There are currently several pain points when using… labeled dims with `LKJCholeskyCov`:
1. The distribution samples 2d matrices, but passing a pair of dimensions results in an error, because internally the distribution is represented in packed lower-triangular form.
2. Any dimensions given to `LKJCholeskyCov` are not propagated to the internally generated deterministics, `{name}_std` and `{name}_corr`. If a user wants these to be labeled, they have to pass a long dictionary to `idata_kwargs`.
3. After sampling, the 1's on the diagonal of all samples drawn from `{name}_corr` cause an error in `arviz` when computing within-chain variance.
This PR tries to correct all three of these. Here is an example model under this PR:
```python
from string import ascii_uppercase

import numpy as np
import pymc as pm

n = 3
n_obs = 100
mean = np.zeros(n)
L = np.random.normal(size=(n, n))
cov = L @ L.T
data = np.random.multivariate_normal(mean=mean, cov=cov, size=(n_obs, ))

with pm.Model(coords={'dim': ascii_uppercase[:n],
                      'dim_aux': ascii_uppercase[:n]},
              coords_mutable={'obs_idx': np.arange(n_obs, dtype='int')}) as mod:
    sd_dist = pm.Exponential.dist(1)
    # Passing a pair of dims here relies on the changes in this PR
    chol, *_ = pm.LKJCholeskyCov('chol', n=n, sd_dist=sd_dist, eta=1, dims=['dim', 'dim_aux'])
    obs = pm.MvNormal('obs', mu=0, chol=chol, observed=data, dims=['obs_idx', 'dim'])
    idata = pm.sample()
```
First, I pass two dimensions to `LKJCholeskyCov` -- one for the columns, and one for the rows. This corresponds to the expectation that I am drawing from a matrix-valued random variable.
Internally, I take the Cartesian product between these two dims, and use the lower triangle of the resulting matrix to make and register a new coordinate: `packed_tril_{name}`. This is then set as the dims on `packed_chol`.
Next, only the upper triangle (excluding the diagonal) of the correlation matrix is stored in a deterministic. Another new coordinate is registered: `corr_{name}`.
Finally, the first dim is used to add a labeled dimension to `{name}_std`.
This results in the following graph:
![image](https://github.com/pymc-devs/pymc/assets/48652735/b1edfde8-9823-4271-af22-cfbd556fc0fb)
Here is the result plotted with `az.plot_trace`:
![image](https://github.com/pymc-devs/pymc/assets/48652735/4d2b5f0c-7cf5-4803-9773-50520159d5e4)
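The coordinate construction described above (the Cartesian product of the two dims, split into a lower triangle for the packed Cholesky factor and a strict upper triangle for the correlations) can be sketched in plain Python. This is only an illustration: the helper names and the `"row,col"` label format are hypothetical, and the PR itself may order or format the labels differently.

```python
from itertools import product
from string import ascii_uppercase


def packed_tril_labels(rows, cols):
    """Labels for the lower triangle (diagonal included), in row-major order."""
    pairs = product(enumerate(rows), enumerate(cols))
    return [f"{r},{c}" for (i, r), (j, c) in pairs if j <= i]


def upper_offdiag_labels(rows, cols):
    """Labels for the strict upper triangle (diagonal excluded)."""
    pairs = product(enumerate(rows), enumerate(cols))
    return [f"{r},{c}" for (i, r), (j, c) in pairs if j > i]


dim = list(ascii_uppercase[:3])
tril = packed_tril_labels(dim, dim)     # n * (n + 1) / 2 labels
corr = upper_offdiag_labels(dim, dim)   # n * (n - 1) / 2 labels
print(tril)
print(corr)
```

For `n = 3` this gives six labels for the packed Cholesky entries and three for the off-diagonal correlations, which is why the constant diagonal of ones never has to be stored.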
This PR still needs a bit of work, including:
1. Unit tests for the new functionality
2. Documentation
3. The generated dimensions are added in a "hacky" way; I hope this can be improved
4. The names on the "packed" dimensions of chol and cov are not great. It would be nice if a MultiIndex could be specified here, but I don't think it's currently possible without some changes (see [here](https://discourse.pymc.io/t/pymc-arviz-how-to-make-the-most-of-labeled-coords-and-dims-in-pymc-4-0/9581/9?u=jessegrabowski), but maybe this is out of date?)
It's also possible that problem (3) is really a problem on the arviz side of things and should be fixed there instead of here. But in that case, it would still be nice to propagate the matrix dims to the full square correlation matrix.
Because of these points, I'm marking this as a draft PR. But I would still like feedback on the idea of automatically generating coords, or at least on how dim handling can be improved in `LKJCholeskyCov`.
**Checklist**
+ [ ] Explain important implementation details 👆
+ [ ] Make sure that [the pre-commit linting/style checks pass](https://docs.pymc.io/en/latest/contributing/python_style.html).
+ [ ] Link relevant issues (preferably in [nice commit messages](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html))
+ [ ] Are the changes covered by tests and docstrings?
+ [ ] Fill out the short summary sections 👇
## Major / Breaking Changes
- None
## New features
- Allow `LKJCholeskyCov` to generate and register new model `coords` corresponding to the distributions it internally registers
## Bugfixes
- Allows plotting of generated correlation matrix in `az.plot_trace`, but that might not be something that should be fixed on the PyMC side. See discussion above.
## Documentation
- None
## Maintenance
- None
----
:books: Documentation preview :books:: https://pymc--6828.org.readthedocs.build/en/6828/
This is helpful. It’s nice to know that work is being done.
I think it’s also possible to use xarray’s `rename_dims`, `rename_vars`, and `assign_coords` to do a lot of this after sampling.
The one thing that I still can’t figure out is how to get xarray to combine dimensions that are supposed to be the same, like `chol_cov_corr_dim_0` and `chol_cov_stds_dim_0`.
If two dimensions have the same name, they will be combined automatically.
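A minimal sketch of that renaming, assuming a posterior whose default dims look like `chol_cov_stds_dim_0` and `chol_cov_corr_dim_0`. The dataset below is a synthetic stand-in for `idata.posterior`, and the names `vars` and `vars_bis` and the labels `a`, `b`, `c` are made up for illustration. The auto-generated integer coordinates are dropped first so that only the dimension names are renamed.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
k = 3

# Synthetic stand-in for idata.posterior with default dim names and coords
posterior = xr.Dataset(
    {
        "chol_cov_stds": (("chain", "draw", "chol_cov_stds_dim_0"),
                          rng.random((2, 5, k))),
        "chol_cov_corr": (("chain", "draw", "chol_cov_corr_dim_0", "chol_cov_corr_dim_1"),
                          rng.random((2, 5, k, k))),
    },
    coords={
        "chol_cov_stds_dim_0": np.arange(k),
        "chol_cov_corr_dim_0": np.arange(k),
        "chol_cov_corr_dim_1": np.arange(k),
    },
)

old_dims = ["chol_cov_stds_dim_0", "chol_cov_corr_dim_0", "chol_cov_corr_dim_1"]
labeled = (
    posterior
    .drop_vars(old_dims)  # remove the default integer coordinates
    .rename_dims({
        "chol_cov_stds_dim_0": "vars",      # same name as the next entry,
        "chol_cov_corr_dim_0": "vars",      # so the two dims become one
        "chol_cov_corr_dim_1": "vars_bis",
    })
    .assign_coords(vars=["a", "b", "c"], vars_bis=["a", "b", "c"])
)
print(labeled)
```

After this, a selection such as `labeled.sel(vars="a")` applies to both variables at once, because they now share the `vars` dimension.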
I would say the current recommended way to get coords onto `LKJCholeskyCov` is to use the `idata_kwargs` kwarg in `pm.sample`. This tutorial shows how to do it:
# Lots of other stuff snipped
sd_dist = pm.Exponential.dist(0.5, shape=(2,))
chol, corr, stds = pm.LKJCholeskyCov("chol", n=2, eta=2.0, sd_dist=sd_dist)
# More snipping
covariation_intercept_slope_trace = pm.sample(
    1000,
    tune=3000,
    target_accept=0.95,
    idata_kwargs={"dims": {"chol_stds": ["param"], "chol_corr": ["param", "param_bis"]}},
)
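Applied to the model from the original question, the same pattern might look like the sketch below. The helper `lkj_idata_kwargs`, the function `fit`, and the second coordinate name `vars_bis` are hypothetical. The sketch assumes the deterministics are registered as `{name}_stds` and `{name}_corr`, as the `chol_stds` and `chol_corr` keys above suggest. The prior on `mu` is written as independent standard normals, which matches the identity-covariance `MvNormal` in the question.

```python
def lkj_idata_kwargs(name, dim, dim_bis):
    """Build idata_kwargs labeling the deterministics of an LKJCholeskyCov.

    Assumes they are named f"{name}_stds" and f"{name}_corr".
    """
    return {"dims": {f"{name}_stds": [dim], f"{name}_corr": [dim, dim_bis]}}


def fit(data, draws=1000, tune=1000, seed=1234):
    """Fit the multivariate normal from the question with labeled dims."""
    import pymc as pm  # imported lazily so the helper above works without PyMC

    k = data.shape[1]
    # The correlation matrix needs two distinct dim names of equal length
    coords = {"vars": list(range(k)), "vars_bis": list(range(k))}
    with pm.Model(coords=coords):
        # dims= attaches the "vars" coordinate directly to mu
        mu = pm.Normal("mu", mu=0.0, sigma=1.0, dims="vars")
        chol, corr, stds = pm.LKJCholeskyCov(
            "L_chol", n=k, eta=2, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
        )
        pm.MvNormal("y_obs", mu=mu, chol=chol, observed=data)
        return pm.sample(
            draws,
            tune=tune,
            random_seed=seed,
            idata_kwargs=lkj_idata_kwargs("L_chol", "vars", "vars_bis"),
        )


print(lkj_idata_kwargs("L_chol", "vars", "vars_bis"))
```

Calling `fit(data)` with the `data` array from the question should then give `mu` and `L_chol_stds` a shared `vars` dimension in the returned idata object, though this has not been checked against every PyMC version.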
Thanks. That was actually a huge help.
Opher