Indexing in hierarchical modeling

Hi, I’m trying to build a varying intercept model and here is my data structure:
Input:
for each subject i, the input is a flattened upper triangular 9090 matrix. We have 100 subject, so the input would be an array of 1004005. The 9090 matrix is the correlation between 90 regions. So, the data structure can be understood as for each subject i, there are 90 time series measurements in each region, we take pair-wise correlation of these time series, result in a 9090 matrix for each subject, then we take the flattened upper triangle of these matrices as input.
These 90 regions can be further divided into 13 mutually exclusive groups. The group follows a Categorical distribution. That means, for each region, we generate a length 13 vector with 1 indicates the group this region belongs to. So, with 13 groups, 91 group-wise correlation can be constructed. Each of 4005 features should belongs to one of 91 new groups. How do I code this in pymc?
Below is my code with simulations:

import numpy as np
# Define model parameter
np.random.seed(123)
v = 90
Q = 13
n = 100
group_vec = np.concatenate([
    np.repeat(0, 7),
    np.repeat(1, 6),
    np.repeat(2, 8),
    np.repeat(3, 8),
    np.repeat(4, 5),
    np.repeat(5, 2),
    np.repeat(6, 8),
    np.repeat(7, 5),
    np.repeat(8, 10),
    np.repeat(9, 7),
    np.repeat(10, 8),
    np.repeat(11, 5),
    np.repeat(12, 11)
])
group_size = np.array([7, 6, 8, 8, 5, 2, 8, 5, 10, 7, 8, 5, 11])

# Define the latent means
latent_mu = np.zeros([n,int(Q*(Q-1)/2+Q)])
for i in range(n):
    latent_mu[i] = np.random.normal(0, 1, int(Q*(Q-1)/2+Q))
latent_mu_matrix = np.zeros([n, Q, Q])
for i in range(n):
    ind = np.triu_indices(Q)
    latent_mu_matrix[i, ind[0], ind[1]] = latent_mu[i]
latent_sigma = 1

# generate the connectivity matrix
# generate a_i_jl
A = np.zeros([n, v, v])

for i in range(n):
    for j in range(v):
        for l in range(v):
            if j == l:
                A[i,j,l] = 5
            elif j < l:
                A[i,j,l] = np.random.normal(latent_mu_matrix[i, group_vec[j], group_vec[l]], latent_sigma)
    ind_lower = np.tril_indices(v, -1)
    A[i][ind_lower] = A[i].T[ind_lower]

ind = np.triu_indices(v, k=1)
a_jl = np.zeros([n, int(v*(v-1)/2)])
for i in range(n):
    a_jl[i] = A[i][ind]
    
for i in range(n):
    triui = np.triu_indices_from(np.triu(A[i]), k=1)
col_names = []
for i in range(len(triui[0])):
    col_names.append('ROI' + str(triui[0][i]) + '_' + str(triui[1][i]))


triui_net = np.triu_indices_from(np.triu(np.zeros([Q,Q])), k=0)
net_net_names = []
for i in range(len(triui_net[0])):
    net_net_names.append('NET' + str(triui_net[0][i]) + '_' + str(triui_net[1][i]))
# generate y
latent_index = [0, 5, 9, 29, 44, 48, 50, 79, 85, 87]
random_sigma = 1.5
sigma = 2
T = np.array([0,1])
age = np.repeat(T, n)
N = n*len(T)
beta_latent = np.zeros(int(Q*(Q-1)/2+Q))
for i in range(int(Q*(Q-1)/2+Q)):
    if i in latent_index:
        beta_latent[i] = 5


beta_age = 10
y = np.tile(np.random.normal(0, random_sigma, n), 2) + np.tile(np.dot(latent_mu, beta_latent), 2) + beta_age*age + np.random.normal(0, sigma, N)

import pandas as pd
# sub_id, mn_sub = pd.factorize(pd.Series(np.tile(np.arange(n), 2)))
sub_id, mn_sub = pd.factorize(pd.Series(np.arange(n)))
region, mn_region = pd.factorize(pd.Series(np.arange(v).astype(int)))
# net, mn_net = pd.factorize(pd.Series(group_vec))
net, mn_net = pd.factorize(np.arange(Q).astype(int))
latent, mn_latent = pd.factorize(pd.Series(net_net_names))
coords = {'sub_id': mn_sub, 'region': mn_region, 'net':mn_net, 'latent':mn_latent}

def get_flattened_index(n, i, j):
    """
    Convert row (i) and column (j) indices of an upper triangular matrix
    (including diagonal) to the flattened index.

    :param n: int, size of the n x n matrix
    :param i: int, row index (0-based)
    :param j: int, column index (0-based)
    :return: int, flattened index (0-based)
    """
    return int((i * (2 * n - i + 1)) // 2 + (j - i))

def upper_triangular_to_flattened_index(n, i, j):
    if i >= j:
        raise ValueError("This function is only valid for the upper triangular part (excluding diagonal), so i must be less than j.")
    
    flattened_index = ((2*n - i - 1) * i // 2 + (j - i - 1))
    return flattened_index
import pymc as pm
import pytensor.tensor as pt
import sys
import pymc_experimental as pmx
sys.setrecursionlimit(3000)
with pmx.MarginalModel(coords=coords) as varying_intercept:
    sub_idx = pm.Data("sub_idx", sub_id, dims="sub_id")
    region_idx = pm.Data("region_idx", region.astype(int), dims="region")
    net_idx = pm.Data("net_idx", net.astype('int'), dims="net")
    latent_idx = pm.Data("latent_idx", latent, dims='latent')
    corr = pm.Data("corr", A, dims=("obs_id", "conn", 'conn'))
    # age = pm.Data("age", age, dims="obs_id")
    alpha = np.ones([v,Q])  # Uniform prior
    probs = pm.Dirichlet('probs', a=alpha[region,:], dims=('region', 'net'))
    region_net = pm.Categorical('region_net',p = probs[region,:])
    # print(region_net.eval())
    # region_net_tensor = pt.argmax(region_net, axis=1)
    region_net_vec = region_net.eval()
        
    latent_sigma = pm.Exponential("latent_sigma", 1)
    m_qr = pm.Normal("m_qr", mu=0, sigma=1, dims=('sub_id', 'latent'))
    region_net_idx = np.zeros(int(v*(v-1)/2), dtype=int)
    i = 0
    test = 0
    for i in range(int(Q*(Q-1)/2+Q)):
    # for q in range(Q):
    #     for r in range(q,Q):
        q_idx = np.where(region_net_vec==triui_net[0][i])[0]
        r_idx = np.where(region_net_vec==triui_net[1][i])[0]
        latent_qr_i = []
        for j in range(len(q_idx)):
            for l in range(len(r_idx)):
                # if j<l:
                try:
                    latent_i = upper_triangular_to_flattened_index(v, q_idx[j], r_idx[l])
                    latent_qr_i.append(latent_i)
                except:
                    pass
        test+=len(latent_qr_i)
        region_net_idx[latent_qr_i] = i
        i+=1
                # latent_idx = pm.Data('latent_idx', latent_idx, dims='latent')

sigma=pt.repeat(latent_sigma, n), observed=a_qr.T)
    # print(latent_idx)
    m_qr_like = pm.Normal('m_qr_like', mu = m_qr[sub_id,latent_idx], sigma = latent_sigma, observed = a_jl[:,region_net_idx])
    # Priors
    mu_a = pm.Normal("mu_a", mu=0.0, sigma=10.0)
    sigma_a = pm.Exponential("sigma_a", 1)

    # Random intercepts
    alpha = pm.Normal("alpha", mu=mu_a, sigma=sigma_a, dims="sub_id")
    # Model error
    sd_y = pm.Exponential("sd_y", 1)
    # Common slope
    beta_age = pm.Normal("beta", mu=0.0, sigma=10.0)
    lbd = pm.HalfStudentT("lbd", nu = 5, dims=("latent",))
    tau = pm.HalfStudentT("tau", nu =5, sigma=2*sd_y)
    z = pm.Normal("z", mu=0, sigma=1, dims=("latent",))
    beta_latent = pm.Deterministic("beta_latent", z*lbd*tau)


    # Expected value
    y_hat = alpha[sub_idx] + beta_age * age + pm.math.dot(pt.tile(m_qr,(2, 1)), beta_latent.T)

    # Data likelihood
    y_like = pm.Normal("y_like", mu=y_hat, sigma=sd_y, observed=y)

Thank you!