Mixed parameterization of a hierarchical model

I was really interested in Betancourt’s use of a mixture of centered and non-centered parameterizations within a single hierarchical model, which he describes in his “Hierarchical Modeling” blog post. I implemented this model in PyMC3 (Betancourt used Stan) and came up with the model shown below. It feels a bit clunky, so I am curious whether anyone has a different implementation or suggestions for improvement.
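For reference, these are the two parameterizations being mixed, written for group k using the same symbols as the variable names in the PyMC3 code (mu, tau, eta, theta, sigma, idx). As in PyMC3, the second argument of the normal is a standard deviation. Both give the same prior on theta_k; they differ only in the posterior geometry the sampler has to explore.

```latex
\begin{aligned}
\text{centered:} \quad & \theta_k \sim \mathcal{N}(\mu, \tau) \\
\text{non-centered:} \quad & \eta_k \sim \mathcal{N}(0, 1), \qquad \theta_k = \mu + \tau\,\eta_k \\
\text{likelihood:} \quad & y_n \sim \mathcal{N}\big(\theta_{\mathrm{idx}[n]}, \sigma\big)
\end{aligned}
```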

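from typing import Any

import arviz as az
import numpy as np
import pymc3 as pm

RANDOM_SEED = 1234  # placeholder value; defined elsewhere in the original code


def mixed_param_model(  # hypothetical name; the original def line was lost
    d: dict[str, Any],
    cp_idx: np.ndarray,
    ncp_idx: np.ndarray,
    draws: int = 10000,
    tune: int = 1000,
) -> tuple[pm.Model, az.InferenceData, dict[str, np.ndarray]]:
    n_cp = len(np.unique(cp_idx))
    n_ncp = len(np.unique(ncp_idx))
    with pm.Model() as mixp_model:
        # Population-level location and scale shared by all groups
        mu = pm.Normal("mu", 0, 5)
        tau = pm.HalfNormal("tau", 5)

        # Centered groups: theta drawn directly from N(mu, tau)
        theta_cp = pm.Normal("theta_cp", mu, tau, shape=n_cp)
        # Non-centered groups: standard-normal offsets, then shift and scale
        eta_ncp = pm.Normal("eta_ncp", 0, 1, shape=n_ncp)
        theta_ncp = pm.Deterministic("theta_ncp", mu + tau * eta_ncp)

        # Put each group's theta back at its original position 0..K-1
        _theta = list(range(d["K"]))
        for i, t in enumerate(cp_idx):
            _theta[t] = theta_cp[i]
        for i, t in enumerate(ncp_idx):
            _theta[t] = theta_ncp[i]

        theta = pm.Deterministic("theta", pm.math.stack(_theta))

        y = pm.Normal("y", theta[d["idx"]], d["sigma"], observed=d["y"])

        # Sampler arguments reconstructed from the function signature;
        # return_inferencedata=True matches the az.InferenceData return type
        trace = pm.sample(
            draws=draws,
            tune=tune,
            random_seed=RANDOM_SEED,
            return_inferencedata=True,
        )
        y_post_pred = pm.sample_posterior_predictive(
            trace=trace, random_seed=RANDOM_SEED
        )
    return mixp_model, trace, y_post_pred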
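One possible way to drop the Python loop that assembles _theta is to concatenate the two vectors once and reorder them with an index array computed in NumPy. The sketch below checks that index logic with plain NumPy values standing in for the PyMC3 tensors; interleave_order is a hypothetical helper, and it assumes cp_idx and ncp_idx together contain every group index 0..K-1 exactly once.

```python
import numpy as np


def interleave_order(cp_idx: np.ndarray, ncp_idx: np.ndarray) -> np.ndarray:
    """Hypothetical helper: positions that reorder
    concat([theta_cp, theta_ncp]) into group order 0..K-1."""
    combined = np.concatenate([cp_idx, ncp_idx])  # group label of each slot
    K = combined.size
    if not np.array_equal(np.sort(combined), np.arange(K)):
        raise ValueError("cp_idx and ncp_idx must partition 0..K-1")
    # argsort of the slot labels is the inverse permutation: entry k is the
    # slot in the concatenated vector that holds group k's theta.
    return np.argsort(combined)


# NumPy stand-ins for the model's theta_cp / theta_ncp tensors
cp_idx = np.array([0, 3, 4])
ncp_idx = np.array([1, 2, 5])
theta_cp = np.array([10.0, 13.0, 14.0])   # values for groups 0, 3, 4
theta_ncp = np.array([11.0, 12.0, 15.0])  # values for groups 1, 2, 5

order = interleave_order(cp_idx, ncp_idx)
theta = np.concatenate([theta_cp, theta_ncp])[order]
print(theta)  # [10. 11. 12. 13. 14. 15.]
```

Inside the model, the loops and stack could then be replaced by something like theta = pm.Deterministic("theta", pm.math.concatenate([theta_cp, theta_ncp])[order]), with order computed once outside the model. Theano's integer-array indexing should handle this, but I have only checked the NumPy side here.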

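For context, d is a data dictionary in which "K" is the number of individuals (groups), "idx" gives the group index of each observation, "sigma" is the observation noise scale, and "y" holds the observed values. cp_idx and ncp_idx are the indices of the individuals that should get a centered and a non-centered parameterization, respectively.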
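To make the expected inputs concrete, here is a sketch of a synthetic d with the keys used by the model, plus one way to split the groups. Everything here is hypothetical: the group sizes, the noise scale, and the split rule are my assumptions, not Betancourt's data. The split follows the general idea that groups strongly informed by their own data tend to sample well centered, while sparsely observed groups tend to do better non-centered; the threshold of 10 observations is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

K = 6                                     # number of groups (hypothetical)
n_obs = np.array([1, 2, 40, 3, 50, 60])   # observations per group (hypothetical)
idx = np.repeat(np.arange(K), n_obs)      # group index of each observation
true_theta = rng.normal(0.0, 2.0, size=K)
sigma = 1.0                               # known observation noise scale

d = {
    "K": K,
    "idx": idx,
    "sigma": sigma,
    "y": rng.normal(true_theta[idx], sigma),
}

# Assumed heuristic: centered for data-rich groups, non-centered otherwise.
threshold = 10
cp_idx = np.flatnonzero(n_obs >= threshold)
ncp_idx = np.flatnonzero(n_obs < threshold)
print(cp_idx, ncp_idx)  # [2 4 5] [0 1 3]
```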