Converting Multi-Level Hierarchical GLM to Non-Centered Form

Hey, I’m looking to convert this model to a non-centered form, but nothing I’ve tried so far works well. I built a version where the z value was passed down to each level, but I think that defeats the purpose of the non-centered form: the lower-level normal distributions are still centered on their parent, so divergences can still appear when the sigma values are small.
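For a single level, the centered and non-centered forms define the same prior on the group values $\theta_j$. What changes is which quantity the sampler explores directly:

$$
\underbrace{\theta_j \sim \mathcal{N}(\mu, \sigma)}_{\text{centered}}
\qquad\text{vs.}\qquad
\underbrace{z_j \sim \mathcal{N}(0, 1), \quad \theta_j = \mu + \sigma z_j}_{\text{non-centered}}
$$

In the centered form a small $\sigma$ pins every $\theta_j$ close to $\mu$, which produces the funnel-shaped posterior that typically causes divergences. In the non-centered form the $z_j$ keep unit scale whatever $\sigma$ is. Applied level by level to the slopes in the model below (the `xdims` index is dropped for brevity), each level would get its own independent offsets and take its parent's value as its location, with $m^{\text{class}}$, $\sigma^{\text{class}}$, $m^{\text{region}}$, $\sigma^{\text{region}}$ playing the roles of `mu_m_class_type`, `sigma_m_class_type`, `mu_m_region`, `sigma_m_region`:

$$
\begin{aligned}
m^{\text{class}}_{c} &= \mu_m + \sigma_m\, z^{\text{class}}_{c}, \\
m^{\text{region}}_{r,c} &= m^{\text{class}}_{c} + \sigma^{\text{class}}_{c}\, z^{\text{region}}_{r,c}, \\
m_{v,r,c} &= m^{\text{region}}_{r,c} + \sigma^{\text{region}}_{r,c}\, z_{v,r,c},
\end{aligned}
\qquad z^{\text{class}}_{c},\ z^{\text{region}}_{r,c},\ z_{v,r,c} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1).
$$

The intercepts ($\mu_b$, $\sigma_b$, and so on) follow the same pattern without the `xdims` index.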

If anyone knows the best practice, or just how to put this model into non-centered form, I would really appreciate it!

For reference, here is what I am going for (in a single-level sense): "Why hierarchical models are awesome, tricky, and Bayesian" on the While My MCMC Gently Samples blog.
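The single-level idea can be checked numerically with only the standard library. This is a minimal sketch, and `non_centered_group` is a hypothetical helper, not part of PyMC or the linked post: the sampler would work with the offsets `z`, and the group values `theta` are derived from them.

```python
import random
import statistics


def non_centered_group(mu, sigma, n_groups, seed=0):
    """Draw one level of group effects in non-centered form (hypothetical helper).

    Returns (z, theta): z are the standard-normal offsets a sampler would
    explore, and theta = mu + sigma * z are the implied group values.
    """
    rng = random.Random(seed)
    z = [rng.gauss(0.0, 1.0) for _ in range(n_groups)]
    theta = [mu + sigma * z_j for z_j in z]
    return z, theta


for sigma in (10.0, 0.01):
    z, theta = non_centered_group(mu=1.0, sigma=sigma, n_groups=20_000)
    # The offsets keep unit scale whatever sigma is; only theta shrinks with sigma.
    print(f"sigma={sigma}: sd(z)={statistics.stdev(z):.2f}, sd(theta)={statistics.stdev(theta):.3f}")
```

The spread of `z` stays near 1 in both runs while the spread of `theta` tracks `sigma`, which is why a small group-level sigma does not squeeze the space the sampler has to move through.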

For context, I also tried giving each parameter its own separate z value. That works, but I believe the region effect differs for each class type, so I would like to keep that structure hierarchical.
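A separate z per parameter and per level does not by itself give up the hierarchy: what keeps the regions pooled is that each region's location is its class-level value. The sketch below uses NumPy prior draws and a hypothetical helper `class_region_slopes`; the `sigma_m_class` values are illustrative only. Giving `z_region` its own `class_type` axis is what lets the region effect differ by class type.

```python
import numpy as np


def class_region_slopes(mu_m, sigma_m, sigma_m_class, n_region, rng):
    """Prior draw of class- and region-level slopes in non-centered form.

    mu_m, sigma_m:  shape (n_xdims,)           -- global level
    sigma_m_class:  shape (n_class, n_xdims)   -- spread of regions within each class
    """
    n_class, n_x = sigma_m_class.shape
    # Independent standard-normal offsets for each level (never reused across levels).
    z_class = rng.standard_normal((n_class, n_x))
    z_region = rng.standard_normal((n_region, n_class, n_x))
    # Class level: shifted and scaled by the global parameters.
    m_class = mu_m + sigma_m * z_class                 # (class, x)
    # Region level: the class-level value is the location, so regions are
    # pooled toward their own class and can differ between class types.
    m_region = m_class + sigma_m_class * z_region      # (region, class, x)
    return z_class, z_region, m_class, m_region


rng = np.random.default_rng(42)
mu_m = np.array([1.0])
sigma_m = np.array([10.0])
sigma_m_class = np.array([[0.5], [3.0]])  # illustrative values for two class types
z_class, z_region, m_class, m_region = class_region_slopes(
    mu_m, sigma_m, sigma_m_class, n_region=4, rng=rng
)
print(m_region.shape)  # (4, 2, 1)
```

Broadcasting lines the trailing `("class_type", "xdims")` axes up, which matches the dims order already used in the model below.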

```python
import pymc as pm

# Assumes heir_model_class_type_region_variety is a pm.Model created earlier with coords
# for "xdims", "class_type", "region" and "variety", and that x, y, class_type, region,
# variety, sample_data, X_VARS and SAMPLE_PARAMS are defined earlier (not shown).
with heir_model_class_type_region_variety:
    # global parameters
    mu_m = pm.Normal('mu_m', mu=1, sigma=10, dims="xdims")
    sigma_m = pm.HalfNormal('sigma_m', sigma=10, dims="xdims")

    mu_b = pm.Normal('mu_b', mu=50_000, sigma=100_000)
    sigma_b = pm.HalfNormal('sigma_b', sigma=100_000)

    std = pm.HalfNormal('std', sigma=500_000)

    # class type parameters (centered on the global parameters)
    mu_m_class_type = pm.Normal('mu_m_class_type', mu=mu_m, sigma=sigma_m, dims=("class_type","xdims"))
    sigma_m_class_type = pm.HalfNormal('sigma_m_class_type', sigma=10, dims=("class_type","xdims"))

    mu_b_class_type = pm.Normal('mu_b_class_type', mu=mu_b, sigma=sigma_b, dims="class_type")
    sigma_b_class_type = pm.HalfNormal('sigma_b_class_type', sigma=100_000, dims="class_type")

    # region parameters (centered on the class type parameters)
    mu_m_region = pm.Normal('mu_m_region', mu=mu_m_class_type, sigma=sigma_m_class_type, dims=("region","class_type","xdims"))
    sigma_m_region = pm.HalfNormal('sigma_m_region', sigma=10, dims=("region","class_type","xdims"))

    mu_b_region = pm.Normal('mu_b_region', mu=mu_b_class_type, sigma=sigma_b_class_type, dims=("region","class_type"))
    sigma_b_region = pm.HalfNormal('sigma_b_region', sigma=100_000, dims=("region","class_type"))

    # variety parameters (centered on the region parameters)
    m = pm.Normal('m', mu=mu_m_region, sigma=sigma_m_region, dims=("variety","region","class_type","xdims"))
    b = pm.Normal('b', mu=mu_b_region, sigma=sigma_b_region, dims=("variety","region","class_type"))

    # data
    xdata = pm.Data('xdata', x, mutable=True)
    class_type_data = pm.Data('class_type_data', class_type, mutable=True)
    region_data = pm.Data('region_data', region, mutable=True)
    variety_data = pm.Data('variety_data', variety, mutable=True)
    samples = pm.Data('samples', sample_data, mutable=True)

    # linear predictor for a single x variable
    if X_VARS == 1:
        mean = (xdata[samples,0]*m[variety_data,region_data,class_type_data,0])+b[variety_data,region_data,class_type_data]

    obs = pm.Normal('obs', mu=mean, sigma=std, observed=y)

    heir_trace_class_type_region_variety = pm.sample(**SAMPLE_PARAMS)
```