LKJCorr returning error

Below is an example of group-level standard deviations with 2 groups. I don't think a prior would work in my case, because sd_dist.ndim == 2.
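The covariance construction used in the model below, Σ_g = diag(σ_g) R diag(σ_g) for each group g, can be sanity-checked outside PyMC. This NumPy sketch uses a made-up 3 x 3 correlation matrix and made-up standard deviations (all values are hypothetical). It shows that a 2-D array of standard deviations is not a problem for the construction itself: the per-group loop and a single broadcast expression produce the same stack of covariance matrices.

```python
import numpy as np

# Hypothetical sizes for illustration: 2 groups, 3 predictors.
n_groups, n_pred = 2, 3

# A valid correlation matrix: symmetric, unit diagonal, positive definite.
corr = np.array([
    [1.0, 0.3, -0.2],
    [0.3, 1.0, 0.5],
    [-0.2, 0.5, 1.0],
])

# Group-level standard deviations, one row per group, so sd.ndim == 2.
sd = np.array([
    [1.0, 2.0, 0.5],
    [0.8, 1.5, 3.0],
])

# Loop version: Sigma_g = diag(sd_g) @ R @ diag(sd_g) for each group g.
cov_loop = np.stack(
    [np.diag(sd[g]) @ corr @ np.diag(sd[g]) for g in range(n_groups)]
)

# Broadcast version: Sigma_g[i, j] = sd_g[i] * R[i, j] * sd_g[j].
cov_bcast = sd[:, :, None] * corr[None, :, :] * sd[:, None, :]

print(cov_bcast.shape)  # (2, 3, 3)
```

The broadcast form avoids the Python loop and the stack call. PyTensor tensors accept NumPy-style None indexing, so the same expression should carry over to the PyMC variables, although that has not been tested here.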

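import numpy as np
import pymc as pm
import pytensor.tensor as pt

# predictors (the list of predictor names) is defined elsewhere
n_groups = 2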

coords = {
    "predictors": predictors,
    "predictors_I": predictors,
    "groups": np.arange(n_groups)
}

with pm.Model(coords=coords) as m:
    
    ...

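    # LKJ prior on the correlation matrix shared by all groups
    corr = pm.LKJCorr("corr", n=len(predictors), eta=1, dims=("predictors", "predictors_I"))
    # Group-specific standard deviations, shape (n_groups, n_predictors)
    sd_dist = pm.HalfCauchy("sigma", 1, dims=("groups", "predictors"))
    # Per-group covariance: diag(sigma_g) @ corr @ diag(sigma_g)
    cov = pm.Deterministic(
        "cov",
        pt.stack([pt.diag(sd_dist[i]) @ corr @ pt.diag(sd_dist[i]) for i in range(n_groups)]),
        dims=("groups", "predictors", "predictors_I"),
    )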
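One possible cause of the error is an assumption about the output shape. By default, pm.LKJCorr may return only the n(n-1)/2 strictly upper-triangular correlations as a flat vector rather than an n x n matrix; check the documentation for the installed PyMC version. If so, the 2-D dims=("predictors", "predictors_I") would not match that 1-D output, and corr could not be used directly in a matrix product. The NumPy sketch below rebuilds the full symmetric matrix from such a packed vector. The helper name packed_to_corr is hypothetical, and the row-major ordering of np.triu_indices is an assumption about how the entries are packed.

```python
import numpy as np


def packed_to_corr(packed, n):
    """Hypothetical helper: rebuild an n x n correlation matrix from its
    strictly upper-triangular entries, assumed to be in np.triu_indices order."""
    packed = np.asarray(packed, dtype=float)
    expected = n * (n - 1) // 2
    if packed.ndim != 1 or packed.size != expected:
        raise ValueError(
            f"expected a flat vector of {expected} values, got shape {packed.shape}"
        )
    corr = np.eye(n)  # unit diagonal
    rows, cols = np.triu_indices(n, k=1)
    corr[rows, cols] = packed  # fill the upper triangle
    corr[cols, rows] = packed  # mirror into the lower triangle
    return corr


# Three off-diagonal values for a 3 x 3 matrix: (0,1), (0,2), (1,2).
R = packed_to_corr([0.3, -0.2, 0.5], n=3)
print(R)
```

Inside the model, the same index arrays could be used with pt.set_subtensor to fill a PyTensor matrix. Some PyMC releases may also provide a return_matrix option on pm.LKJCorr, but confirm that in the documentation before relying on it.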