Out of sample/model predictions in hierarchical model with new observed data

@JAB great idea using pm.Flat in the prediction model so that only the posterior draws are ever used for the population parameters mu, tau, and sigma (even better if there were a dedicated FromTrace). This way the predictive model gets a much clearer syntax.

Thanks @ricardoV94 and @jessegrabowski for clarifying this.

Regarding the divergences: I slightly modified the data-generating code and the model, and I am now able to sample with a “more realistic” distribution of the number of observations per group without observing divergences or low effective sample sizes, cf. the code below

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm


def generate_data(num_obs: int = 10, num_groups: int = 10):
    rng = np.random.default_rng(98)

    # Number of observations per group: shifted Poisson, so every group has at least one
    n_obs = pm.draw(pm.Poisson.dist(mu=num_obs - 1), num_groups, random_seed=rng) + 1

    # Random group means (seeded for reproducibility)
    individual_means = pm.draw(pm.Normal.dist(mu=2, sigma=3), num_groups, random_seed=rng)
    # Data
    return pd.DataFrame({
        "data": np.concatenate([pm.draw(pm.Normal.dist(mu=mu, sigma=3), n, random_seed=rng) for mu, n in zip(individual_means, n_obs)]),
        "idx": np.concatenate([[ii] * n for ii, n in enumerate(n_obs)]),
    })


def make_model(frame: pd.DataFrame) -> pm.Model:
    with pm.Model(coords={"idx": frame["idx"].unique()}) as model:
        # Data
        g = pm.Data("g", frame["idx"].to_numpy())
        y = pm.Data("y", frame["data"].to_numpy())

        # Population mean
        mu = pm.Normal("mu", 0., 5.)
        # Across group variability of the mean
        tau = pm.HalfNormal("tau", 3.)
        # Population standard deviation
        sigma = pm.HalfNormal("sigma", 3.)

        # Group specific mean
        eps_mu_g = pm.ZeroSumNormal("eps_mu_g", dims="idx")
        mu_g = mu + tau * eps_mu_g

        # Likelihood
        pm.Normal("y_obs", mu=mu_g[g], sigma=sigma, observed=y)
        return model


if __name__ == "__main__":
    frame = generate_data(num_obs=2, num_groups=1000)
    model = make_model(frame)
    with model:
        trace = pm.sample()
    summary = az.summary(
        trace, var_names=["mu", "tau", "sigma"], round_to=4
    ).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    print(summary)

There are some important changes though:
i) I am using a Poisson distribution to mimic the fact that you have a large number of groups with a small number of samples each; I can go down to num_obs=2 without observing divergences. num_obs is the expected number of observations per group in the generated data set. The way n_obs is set up (a Poisson with mean num_obs - 1, shifted by one) ensures that every group has at least one sample and that num_obs is indeed the expectation; see the quick check below.
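As a quick sanity check of that expectation (this snippet is my own illustration, not part of the model):

# Shifted Poisson: if K ~ Poisson(num_obs - 1), then E[K + 1] = num_obs
num_obs = 2
draws = pm.draw(pm.Poisson.dist(mu=num_obs - 1), 100_000, random_seed=42) + 1
print(draws.mean())  # ~2.0; draws.min() is at least 1 by construction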

I am not entirely sure why I cannot go lower than num_obs=2, but I assume that there is just a minimum number of samples needed to “identify” tau. This would need a lot more thought though; I am just speculating here.

ii) The model uses a “non-centered parametrisation”, i.e. I am using standard normal increments eps_mu_g and computing the group-specific mean as mu_g = mu + tau * eps_mu_g. This is a recommended approach whenever the group-specific likelihoods are only weakly informative (if I recall correctly). You can find a very informed :wink: exposition on those points here. In the problem above there are only very few data points for each group, so I think this indeed qualifies as a weakly informative likelihood; a sketch contrasting the two parametrisations follows.
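For concreteness, here is a minimal sketch of the difference, assuming mu and tau are already defined inside a model context as above (the centered variant and the name mu_g_centered are shown only for contrast):

# Centered: the sampler must explore mu_g and tau jointly -> funnel-shaped geometry
mu_g_centered = pm.Normal("mu_g_centered", mu=mu, sigma=tau, dims="idx")

# Non-centered: zero-sum standard-normal increments, shifted and scaled
# deterministically, which decouples the geometry of eps_mu_g from tau
eps_mu_g = pm.ZeroSumNormal("eps_mu_g", dims="idx")
mu_g = pm.Deterministic("mu_g", mu + tau * eps_mu_g, dims="idx")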

iii) I am dropping the group-specific standard deviations sigma_g from the model and just using the population standard deviation sigma in the likelihood. Here I do not have a very informed argument, just my gut feeling that with such a low number of data points per group it becomes impossible to “identify” group-specific standard deviations. Whether this is an admissible approach for your situation, I of course cannot tell. For the data-generating process the across-group variability of sigma_g is in fact zero, so I am not sure whether that caused some of the divergences. It can also be important to use “zero-suppressing” priors for the population parameters tau and sigma, so this is another point to try out (after re-parametrising the model). Alternatively, the log of the group-specific standard deviations could be modelled with a hierarchical normal model, using an exp transform to get from log-space back to standard deviations; a sketch of that variant follows. This is also discussed here.
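A minimal sketch of that last alternative, replacing sigma and the likelihood line in make_model (the names log_sigma, tau_sigma, and eps_sigma_g are my own, and the prior scales are placeholders to be tuned):

# Hierarchical normal model on log(sigma_g), mapped back to scale via exp
log_sigma = pm.Normal("log_sigma", 0.0, 1.0)   # population-level log standard deviation
tau_sigma = pm.HalfNormal("tau_sigma", 0.5)    # across-group variability in log-space
eps_sigma_g = pm.ZeroSumNormal("eps_sigma_g", dims="idx")
sigma_g = pm.Deterministic("sigma_g", pm.math.exp(log_sigma + tau_sigma * eps_sigma_g), dims="idx")

# Likelihood with group-specific standard deviations
pm.Normal("y_obs", mu=mu_g[g], sigma=sigma_g[g], observed=y)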

The predictive model would then have to be adapted accordingly, but I assume (hope) that this is rather straightforward; a sketch of what it could look like is below.
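For instance, a minimal sketch following @JAB's pm.Flat idea (the names new_idx, pred_model, eps_mu_g_new, and y_pred are mine; since pm.Flat cannot be forward-sampled, sample_posterior_predictive is forced to take mu, tau, and sigma from the trace):

# Predict for groups not seen during fitting
new_idx = ["new_group_a", "new_group_b"]
with pm.Model(coords={"idx": new_idx}) as pred_model:
    mu = pm.Flat("mu")
    tau = pm.Flat("tau")
    sigma = pm.Flat("sigma")

    # Fresh name and plain standard normal: new groups are not bound by the
    # zero-sum constraint of the fitted groups and are re-drawn from the prior
    eps_mu_g_new = pm.Normal("eps_mu_g_new", 0.0, 1.0, dims="idx")
    mu_g = mu + tau * eps_mu_g_new

    pm.Normal("y_pred", mu=mu_g, sigma=sigma, dims="idx")

    # eps_mu_g_new is not in the trace, so it is resampled automatically
    ppc = pm.sample_posterior_predictive(trace, var_names=["y_pred"])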
