Simpson's paradox and mixed models - large number of groups

I think the issue is the usual panel-data problem of unobserved confounding, caused by the group-level mx term in the slope. The causal graph is something like this:

[Image: causal graph]
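To make the confounding concrete, here is a small NumPy simulation. The data-generating process, coefficients and function names are hypothetical and chosen only for illustration. A group-level variable mx raises x but lowers y, while within each group y increases with x. Pooling all observations gives a slope with the wrong sign (the Simpson's paradox pattern). Subtracting group means first, which is the within-group or fixed-effects estimator, recovers the true positive slope because mx is constant within a group.

```python
import numpy as np

def simulate(n_groups=50, n_per_group=20, true_slope=1.0, seed=0):
    """Hypothetical DGP: group confounder mx -> x, mx -> y, and x -> y."""
    rng = np.random.default_rng(seed)
    mx = rng.normal(0.0, 1.0, size=n_groups)  # unobserved group-level confounder
    group_idx = np.repeat(np.arange(n_groups), n_per_group)
    # mx pushes x up ...
    x = 2.0 * mx[group_idx] + rng.normal(0.0, 0.5, size=group_idx.size)
    # ... and pushes y down, while x itself pushes y up
    y = (true_slope * x - 5.0 * mx[group_idx]
         + rng.normal(0.0, 0.5, size=group_idx.size))
    return x, y, group_idx

def ols_slope(x, y):
    """Slope of a simple least-squares fit of y on x (with intercept), ignoring groups."""
    xc = x - x.mean()
    return float(xc @ (y - y.mean()) / (xc @ xc))

def within_slope(x, y, group_idx):
    """Slope after subtracting each group's mean from x and y (within estimator)."""
    counts = np.bincount(group_idx)
    x_dm = x - (np.bincount(group_idx, weights=x) / counts)[group_idx]
    y_dm = y - (np.bincount(group_idx, weights=y) / counts)[group_idx]
    return float(x_dm @ y_dm / (x_dm @ x_dm))

x, y, group_idx = simulate()
print("pooled slope:", round(ols_slope(x, y), 2))
print("within-group slope:", round(within_slope(x, y, group_idx), 2))
```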

McElreath covers graphs like this in the bonus section of this video. The basic idea is to model x and y jointly, so that we also obtain estimates of mx:

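import numpy as np
import pymc as pm

# Assumes `data` is a DataFrame with columns x, y and integer group codes group_idx,
# and `group_list` holds the group labels in code order (both defined earlier in the thread).
coords = {"group": group_list, "obs_id": np.arange(len(data))}

with pm.Model(coords=coords) as hierarchical:
    # pm.Data replaces pm.MutableData, which is deprecated in recent PyMC releases
    x = pm.Data("x", data.x, dims="obs_id")
    group_idx = pm.Data("group_idx", data.group_idx, dims="obs_id")

    # Sub-model for x: a latent group-level value mx shifts the mean of x
    x_intercept = pm.Normal("x_intercept")
    x_beta = pm.Normal("x_beta")
    mx_hat = pm.Normal("mx", dims="group")
    x_mu = x_intercept + x_beta * mx_hat[group_idx]

    x_sigma = pm.HalfNormal("x_sigma")
    x_hat = pm.Normal("x_hat", mu=x_mu, sigma=x_sigma, observed=x, dims="obs_id")

    # Varying intercepts, non-centred with zero-sum group offsets
    mu_intercept = pm.Normal("mu_intercept")
    sigma_intercept = pm.Gamma("sigma_intercept", alpha=2, beta=1)
    offset_intercept = pm.ZeroSumNormal("offset_intercept", dims="group")
    intercept = pm.Deterministic("intercept", mu_intercept + sigma_intercept * offset_intercept, dims="group")

    # Varying slopes; mu_slope is the population slope of interest
    mu_slope = pm.Normal("mu_slope")
    sigma_slope = pm.Gamma("sigma_slope", alpha=2, beta=1)
    offset_slope = pm.ZeroSumNormal("offset_slope", dims="group")
    slope = pm.Deterministic("slope", mu_slope + sigma_slope * offset_slope, dims="group")

    # Effect of the latent group value on y, adjusting for the confounder
    mx_slope = pm.Normal("mx_slope")

    mu = (
        intercept[group_idx]
        + slope[group_idx] * x
        + mx_slope * mx_hat[group_idx]
    )
    sigma = pm.HalfNormal("sigma", sigma=2)

    pm.Normal("y", mu=mu, sigma=sigma, observed=data.y, dims="obs_id")
    # nuts_sampler="nutpie" requires the nutpie package to be installed
    idata = pm.sample(nuts_sampler="nutpie")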

This scheme gives unbiased estimates of the population slope.
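For comparison, a cheaper non-latent variant, often called a Mundlak regression, plugs the *observed* group mean of x in place of the latent mx. This is a swapped-in illustration, not what the joint model does: the joint model treats mx as unknown and lets the x sub-model inform it, which also carries uncertainty about the group values into the slope. The sketch below uses the same kind of hypothetical simulated data (mx raises x and lowers y). Because the group mean absorbs everything that is constant within a group, the least-squares coefficient on x matches the within-group slope.

```python
import numpy as np

def simulate(n_groups=50, n_per_group=20, true_slope=1.0, seed=1):
    """Hypothetical DGP: group confounder mx -> x, mx -> y, and x -> y."""
    rng = np.random.default_rng(seed)
    mx = rng.normal(0.0, 1.0, size=n_groups)
    group_idx = np.repeat(np.arange(n_groups), n_per_group)
    x = 2.0 * mx[group_idx] + rng.normal(0.0, 0.5, size=group_idx.size)
    y = (true_slope * x - 5.0 * mx[group_idx]
         + rng.normal(0.0, 0.5, size=group_idx.size))
    return x, y, group_idx

def mundlak_slope(x, y, group_idx):
    """Coefficient on x from OLS of y on an intercept, x, and the observed group mean of x."""
    counts = np.bincount(group_idx)
    x_bar = (np.bincount(group_idx, weights=x) / counts)[group_idx]
    design = np.column_stack([np.ones_like(x), x, x_bar])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return float(coef[1])  # coef = [intercept, slope on x, slope on group mean]

x, y, group_idx = simulate()
print("slope on x with group mean included:", round(mundlak_slope(x, y, group_idx), 2))
```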

The joint model doesn't solve the ESS problem, though; the effective sample size is still very low.
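One plausible contributor to the poor ESS is a reflection symmetry in the model. This is an untested hypothesis, not something established in this thread. mx only enters through the products x_beta * mx and mx_slope * mx, and x_beta, mx_slope and mx all have zero-centred Normal priors. Flipping the signs of all three together therefore leaves both the likelihood and the prior unchanged. The posterior would then have two mirrored modes, and chains that settle in different modes report low ESS and high R-hat for those parameters. The NumPy check below uses hypothetical parameter values and re-implements the two linear predictors from the model. It confirms that both predictors are invariant under the flip.

```python
import numpy as np

def linear_predictors(params, x, group_idx):
    """Return (x_mu, y_mu) as written in the model above, for given parameter values."""
    mx = params["mx"][group_idx]
    x_mu = params["x_intercept"] + params["x_beta"] * mx
    y_mu = (params["intercept"][group_idx]
            + params["slope"][group_idx] * x
            + params["mx_slope"] * mx)
    return x_mu, y_mu

def reflect(params):
    """Flip the signs of x_beta, mx_slope and mx together."""
    flipped = dict(params)
    for name in ("x_beta", "mx_slope", "mx"):
        flipped[name] = -params[name]
    return flipped

# Hypothetical parameter values and covariates, for illustration only
rng = np.random.default_rng(0)
n_groups, n_obs = 5, 40
group_idx = rng.integers(0, n_groups, size=n_obs)
x = rng.normal(size=n_obs)
params = {
    "x_intercept": 0.3, "x_beta": 1.2, "mx_slope": -0.7,
    "mx": rng.normal(size=n_groups),
    "intercept": rng.normal(size=n_groups),
    "slope": rng.normal(size=n_groups),
}

original = linear_predictors(params, x, group_idx)
mirrored = linear_predictors(reflect(params), x, group_idx)
print(all(np.allclose(a, b) for a, b in zip(original, mirrored)))  # True
```

If this symmetry is part of the problem, one untested way to break it is to give x_beta a positive-only prior (for example `pm.HalfNormal`), which removes the mirrored mode.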
