Uses of LKJCholeskyCov and LKJCorr

It is a simple question:

When should I use pm.LKJCorr and pm.LKJCholeskyCov?

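I would use LKJCorr if and when my observed=Y is a matrix, to define the likelihood of the observations. I would not use it as a prior; instead, I would use LKJCholeskyCov with a deliberately degenerate sd_dist like beta(10**3,10**-3). That puts almost all of its mass at 1, so the covariance matrix is effectively a correlation matrix.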
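To see why that choice effectively pins the standard deviations at 1, here is a quick check of where Beta(10**3, 10**-3) puts its mass. It uses SciPy rather than PyMC3 (my assumption, just to inspect the distribution). In PyMC3 itself you would pass something like `sd_dist=pm.Beta.dist(alpha=10**3, beta=10**-3)`.

```python
from scipy import stats

# Beta(a, b) with a = 1e3 and b = 1e-3, the degenerate sd_dist suggested above
sd_prior = stats.beta(1e3, 1e-3)

# The mean is a / (a + b), which is just below 1
print(sd_prior.mean())
# The standard deviation is tiny, so every sd drawn from it is essentially 1
print(sd_prior.std())
```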

2 Likes

Just to add to what chartl said…
There are three differences between LKJCorr and LKJCholeskyCov:

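  1. LKJCorr is a distribution over correlation matrices, while LKJCholeskyCov is a distribution over covariance matrices (more precisely, over their Cholesky factors). These two models should give you the same result, but probably much worse sampler performance, or even divergences, for the one using LKJCorr. Note that LKJCorr returns the packed upper triangle, so it has to be unpacked before rescaling:
```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt
n = 3  # dimension of the covariance matrix (example value)
with pm.Model():  # version 1: LKJCorr, rescaled by per-dimension standard deviations
    sd = pm.HalfNormal('sd', sd=1, shape=n)
    corr_packed = pm.LKJCorr('corr', eta=2, n=n)  # the n*(n-1)/2 upper-triangular entries
    idx = np.zeros((n, n), dtype=int)
    idx[np.triu_indices(n, k=1)] = np.arange(n * (n - 1) // 2)
    corr = tt.fill_diagonal(corr_packed[idx + idx.T], 1.)  # symmetric matrix, unit diagonal
    cov = pm.Deterministic('cov', sd[None, :] * corr * sd[:, None])  # rescale to a covariance
with pm.Model():  # version 2: LKJCholeskyCov
    sd_dist = pm.HalfNormal.dist(sd=1)
    packed_chol = pm.LKJCholeskyCov('chol_cov', eta=2, n=n, sd_dist=sd_dist)
    chol = pm.expand_packed_triangular(n, packed_chol, lower=True)
    cov = pm.Deterministic('cov', tt.dot(chol, chol.T))
```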
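The broadcasting expression in the LKJCorr version is just diag(sd) @ corr @ diag(sd) written elementwise. Here is a quick NumPy check with made-up numbers. The sd values and the correlation matrix below are hypothetical, and NumPy stands in for Theano.

```python
import numpy as np

# Hypothetical standard deviations and a valid (positive definite) correlation matrix
sd = np.array([0.5, 2.0, 3.0])
corr = np.array([[1.0, 0.3, -0.2],
                 [0.3, 1.0, 0.5],
                 [-0.2, 0.5, 1.0]])

# Broadcasting form from the LKJCorr snippet: cov[i, j] = sd[j] * corr[i, j] * sd[i]
cov_broadcast = sd[None, :] * corr * sd[:, None]

# Equivalent matrix form: diag(sd) @ corr @ diag(sd)
cov_matrix = np.diag(sd) @ corr @ np.diag(sd)

print(np.allclose(cov_broadcast, cov_matrix))  # True
# The diagonal of the covariance holds the variances, sd**2
print(np.diag(cov_broadcast))
```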

Usually you probably want covariance matrices in the first place. But if you really want a correlation matrix, you can either set sd_dist to a distribution concentrated at 1, as @chartl suggested, or use an arbitrary sd_dist that the sampler likes and then rescale the covariance matrix to a correlation matrix:

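```python
import pymc3 as pm
import theano.tensor as tt
n = 3  # dimension of the covariance matrix (example value)

with pm.Model():
    sd_dist = pm.HalfNormal.dist(sd=1)  # any sd prior the sampler handles well
    packed_chol = pm.LKJCholeskyCov('chol_cov', eta=2, n=n, sd_dist=sd_dist)
    chol = pm.expand_packed_triangular(n, packed_chol, lower=True)
    cov = pm.Deterministic('cov', tt.dot(chol, chol.T))
    sd = tt.sqrt(tt.diag(cov))  # implied standard deviations
    corr = cov / sd[:, None] / sd[None, :]  # rescale to a unit-diagonal correlation matrix
```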
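Here is a small NumPy sketch of that rescaling, with a made-up Cholesky factor standing in for the sampled one. It also shows an equivalent shortcut, which is my own addition and not something PyMC3 does for you: since cov[i, i] is the squared norm of row i of chol, dividing each row of chol by its norm gives the Cholesky factor of the correlation matrix directly.

```python
import numpy as np

# Hypothetical lower-triangular Cholesky factor of a 3x3 covariance matrix
chol = np.array([[0.5, 0.0, 0.0],
                 [0.6, 1.9, 0.0],
                 [-0.6, 1.2, 2.6]])
cov = chol @ chol.T

# Route used above: divide the covariance by the outer product of the sds
sd = np.sqrt(np.diag(cov))
corr_from_cov = cov / sd[:, None] / sd[None, :]

# Shortcut: normalize each row of chol, which keeps it lower triangular
chol_corr = chol / np.linalg.norm(chol, axis=1, keepdims=True)
corr_from_chol = chol_corr @ chol_corr.T

print(np.allclose(corr_from_cov, corr_from_chol))  # True
print(np.allclose(np.diag(corr_from_chol), 1.0))   # True
```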
  2. The implementation of LKJCorr is unfortunately quite lacking. It doesn't implement a proper bijection between correlation matrices and unconstrained space (R^(n(n-1)/2), one coordinate per off-diagonal entry), so the sampler will usually run into trouble unless you use it only as an observed variable.

  3. LKJCholeskyCov gives you the Cholesky decomposition of the covariance matrix. If you pass the expanded factor to pm.MvNormal via its chol argument, instead of a full covariance matrix, the model will usually be faster and numerically more stable.
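To make items 2 and 3 concrete, here is a NumPy/SciPy sketch. It is not PyMC3's actual code, and the helper names are hypothetical. The first function is a proper bijection from unconstrained values to the Cholesky factor of a correlation matrix. It uses the canonical-partial-correlation construction that Stan uses for its cholesky_factor_corr type. I'm assuming, not claiming, that this is the kind of transform LKJCorr is missing. The second function evaluates the MvNormal log-density from a Cholesky factor using one triangular solve and a sum of logs, which is why passing chol avoids inverting cov.

```python
import numpy as np
from scipy.linalg import solve_triangular
from scipy.stats import multivariate_normal


def corr_cholesky_from_unconstrained(y, n):
    """Map y in R^(n(n-1)/2) to the Cholesky factor of an n x n correlation matrix."""
    z = np.tanh(y)  # each unconstrained value mapped into (-1, 1)
    L = np.zeros((n, n))
    L[0, 0] = 1.0
    k = 0
    for i in range(1, n):
        remaining = 1.0  # squared norm still available for row i
        for j in range(i):
            L[i, j] = z[k] * np.sqrt(remaining)
            remaining -= L[i, j] ** 2  # stays positive because |z| < 1
            k += 1
        L[i, i] = np.sqrt(remaining)  # row i now has unit norm
    return L


def mvnormal_logpdf_chol(x, mu, chol):
    """MvNormal log-density computed from a lower-triangular Cholesky factor of cov."""
    n = len(mu)
    # Solve chol @ alpha = x - mu instead of forming inv(cov)
    alpha = solve_triangular(chol, x - mu, lower=True)
    # log|cov| = 2 * sum(log(diag(chol)))
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (n * np.log(2.0 * np.pi) + log_det + alpha @ alpha)


rng = np.random.default_rng(0)
n = 3
y = rng.normal(size=n * (n - 1) // 2)  # unconstrained parameters
chol_corr = corr_cholesky_from_unconstrained(y, n)
corr = chol_corr @ chol_corr.T  # valid correlation matrix for any y

sd = np.array([0.5, 2.0, 3.0])  # hypothetical standard deviations
chol_cov = sd[:, None] * chol_corr  # Cholesky factor of diag(sd) @ corr @ diag(sd)
mu = np.zeros(n)
x = rng.normal(size=n)

print(np.allclose(np.diag(corr), 1.0))  # True
print(np.isclose(mvnormal_logpdf_chol(x, mu, chol_cov),
                 multivariate_normal(mean=mu, cov=chol_cov @ chol_cov.T).logpdf(x)))  # True
```

Every y gives a valid correlation matrix, and every correlation matrix with a positive-diagonal Cholesky factor comes from exactly one y. That property is what lets the sampler move freely in unconstrained space.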

4 Likes