Below is an example with group-level standard deviations for 2 groups. I don’t think a prior would work in my case, because `sd_dist.ndim == 2`.
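```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

n_groups = 2
coords = {
    "predictors": predictors,
    "predictors_I": predictors,
    "groups": np.arange(n_groups),
}
with pm.Model(coords=coords) as m:
    ...
    # Correlation matrix shared by all groups.
    # Note: depending on the PyMC version, LKJCorr may return only the packed
    # upper triangle rather than the full matrix, which these dims would not fit.
    corr = pm.LKJCorr("corr", n=len(predictors), eta=1, dims=("predictors", "predictors_I"))
    # One standard deviation per (group, predictor), so this variable is 2-D
    sd_dist = pm.HalfCauchy("sigma", 1, dims=("groups", "predictors"))
    # Per-group covariance: diag(sigma_g) @ corr @ diag(sigma_g)
    cov = pm.Deterministic(
        "cov",
        pt.stack([pt.diag(sd_dist[i]) @ corr @ pt.diag(sd_dist[i]) for i in range(n_groups)]),
        dims=("groups", "predictors", "predictors_I"),
    )
```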
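The per-group loop can also be written with broadcasting, since `diag(s) @ R @ diag(s)` just multiplies entry `(j, k)` of `R` by `s[j] * s[k]`. Below is a minimal NumPy sketch, assuming a `(groups, predictors)` array of standard deviations and a single shared correlation matrix; `group_cov` and `group_chol` are hypothetical helper names, not PyMC API. The same indexing should carry over to PyTensor tensors. It also checks that scaling the rows of the correlation's Cholesky factor gives a valid per-group Cholesky factor, in case a parameterization like `pm.MvNormal(chol=...)` is preferred over a full covariance.

```python
import numpy as np


def group_cov(sigma, corr):
    """Hypothetical helper: per-group covariance diag(sigma[g]) @ corr @ diag(sigma[g]).

    sigma: (n_groups, n_pred) standard deviations
    corr:  (n_pred, n_pred) shared correlation matrix
    returns: (n_groups, n_pred, n_pred)
    """
    # Entry (g, j, k) = sigma[g, j] * corr[j, k] * sigma[g, k]
    return sigma[:, :, None] * corr[None, :, :] * sigma[:, None, :]


def group_chol(sigma, corr_chol):
    """Hypothetical helper: per-group Cholesky factor diag(sigma[g]) @ corr_chol.

    Multiplying by a diagonal matrix on the left scales each row.
    """
    return sigma[:, :, None] * corr_chol[None, :, :]


rng = np.random.default_rng(0)
n_groups, n_pred = 2, 3

# Build a valid correlation matrix from a random SPD matrix (illustration only).
a = rng.normal(size=(n_pred, n_pred))
s = a @ a.T + n_pred * np.eye(n_pred)
d = np.sqrt(np.diag(s))
corr = s / np.outer(d, d)

# Positive standard deviations, one row per group
sigma = np.abs(rng.normal(size=(n_groups, n_pred))) + 0.1

# Reference: the explicit per-group loop
loop = np.stack([np.diag(sigma[i]) @ corr @ np.diag(sigma[i]) for i in range(n_groups)])
vec = group_cov(sigma, corr)
print(np.allclose(loop, vec))  # True

# (diag(s) C)(diag(s) C)^T = diag(s) C C^T diag(s) = diag(s) R diag(s)
L = group_chol(sigma, np.linalg.cholesky(corr))
print(np.allclose(L @ np.swapaxes(L, -1, -2), vec))  # True
```

Broadcasting avoids the Python loop over groups, so the graph does not grow with `n_groups`.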