I actually have a few groups (I'm working with a tricky dataset at the moment), so it might look something like this:
with pm.Model(coords=coords) as m:
    ...  # other model components elided

    # Packed off-diagonal correlations of a K x K LKJ correlation matrix.
    # NOTE(review): assumes pm.LKJCorr returns the K*(K-1)/2 strictly-upper-
    # triangular entries in the same row-major order as
    # pt.triu_indices(K, k=1) -- confirm for the PyMC version in use.
    corr_packed = pm.LKJCorr("x", n=K, eta=1, transform=tr)

    # One scale per group combination and per dimension K.
    # ("..." is a placeholder for the remaining group dims group3..group(J-1).)
    sigma = pm.HalfCauchy("sigma", 1, dims=("group1", "group2", ..., "groupJ", "K"))

    # Unpack into a symmetric K x K correlation matrix with a unit diagonal.
    triu_idx = pt.triu_indices(K, k=1)
    corr_upper = pt.set_subtensor(pt.zeros((K, K))[triu_idx], corr_packed)
    corr = pm.Deterministic(
        "corr", pt.eye(K) + corr_upper + corr_upper.T, dims=("K", "K_I")
    )

    # Per-group diagonal scale matrices, shape (group1, ..., groupJ, K, K).
    # The previous loop indexed sigma[i] along the first *group* axis rather
    # than the K axis, and used n_predictors where K is the matching size
    # (NOTE(review): assumes n_predictors == K -- confirm).
    sigma_diag = sigma[..., :, None] * pt.eye(K)

    # cov[..., i, j] = sigma[..., i] * corr[i, j] * sigma[..., j], i.e.
    # diag(sigma) @ corr @ diag(sigma) for every group combination at once.
    # Broadcasting produces exactly the declared dims (groups..., K, K_I).
    cov = pm.Deterministic(
        "cov",
        sigma[..., :, None] * corr * sigma[..., None, :],
        dims=("group1", "group2", ..., "groupJ", "K", "K_I"),
    )

    ...  # rest of the model elided