LKJCorr returning error

I actually have a few groups (I'm working with a tricky dataset at the moment), so it might look something like this:

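import pymc as pm
import pytensor.tensor as pt

with pm.Model(coords=coords) as m:

    ...

    # Packed strict upper triangle of the K x K correlation matrix
    # (`tr` is defined elsewhere, not shown)
    corr_packed = pm.LKJCorr("x", n=K, eta=1, transform=tr)
    # One standard deviation per group combination and per element of K
    sigma = pm.HalfCauchy("sigma", 1, dims=("group1", "group2", ..., "groupJ", "K"))

    # Unpack into a symmetric matrix with ones on the diagonal
    triu_idx = pt.triu_indices(K, k=1)
    corr_upper = pt.set_subtensor(pt.zeros((K, K))[triu_idx], corr_packed)
    corr = pm.Deterministic("corr", pt.eye(K) + corr_upper + corr_upper.T, dims=("K", "K_I"))

    # cov[..., i, j] = sigma[..., i] * corr[i, j] * sigma[..., j], i.e.
    # diag(sigma) @ corr @ diag(sigma) for every group combination at once
    cov = pm.Deterministic(
        "cov",
        sigma[..., :, None] * corr * sigma[..., None, :],
        dims=("group1", "group2", ..., "groupJ", "K", "K_I"),
    )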

    ...
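If you want to sanity-check the unpacking and the per-group covariance outside the model, a plain NumPy sketch is enough. It assumes the packed LKJCorr vector follows the row-major `np.triu_indices(K, k=1)` order (the same order the `pt.triu_indices(K, k=1)` line relies on). The helper names `unpack_corr` and `batched_cov`, the group shape `(2, 4)`, and the `sigma` values are hypothetical and only for illustration. The sketch compares a broadcast product against building `diag(sigma) @ corr @ diag(sigma)` separately for each group combination in a loop.

```python
import numpy as np

def unpack_corr(packed, K):
    """Hypothetical helper: build a symmetric K x K correlation matrix from its
    packed strict upper triangle, assuming np.triu_indices(K, k=1) order."""
    corr_upper = np.zeros((K, K))
    corr_upper[np.triu_indices(K, k=1)] = packed
    return np.eye(K) + corr_upper + corr_upper.T

def batched_cov(sigma, corr):
    """Hypothetical helper: cov[..., i, j] = sigma[..., i] * corr[i, j] * sigma[..., j]."""
    return sigma[..., :, None] * corr * sigma[..., None, :]

K = 3
corr = unpack_corr(np.array([0.2, -0.1, 0.4]), K)

# Made-up standard deviations for 2 x 4 group combinations, K values each
rng = np.random.default_rng(0)
sigma = rng.uniform(0.5, 2.0, size=(2, 4, K))

cov = batched_cov(sigma, corr)
print(cov.shape)  # (2, 4, 3, 3)

# Explicit loop: diag(sigma) @ corr @ diag(sigma) for each group combination
loop_cov = np.empty_like(cov)
for idx in np.ndindex(sigma.shape[:-1]):
    d = np.diag(sigma[idx])
    loop_cov[idx] = d @ corr @ d
print(np.allclose(cov, loop_cov))  # True
```

The broadcast version avoids Python-level loops over the group axes, which also keeps the symbolic graph small when the same expression is written with PyTensor tensors.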