@JAB Please see below for a “version” of your problem, where I do not observe divergences when sampling - while keeping the ability to “use small groups”.
import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
def generate_data(lower: int = 5, upper: int = 100, num_groups: int = 100):
    """Simulate grouped normal data with group-specific means.

    Parameters
    ----------
    lower, upper : int
        Inclusive bounds for the number of observations drawn per group.
    num_groups : int
        Number of groups to simulate.

    Returns
    -------
    pd.DataFrame
        Columns "data" (observations) and "idx" (integer group label).
    """
    rng = np.random.default_rng(98)
    # Number of observations per group
    n_obs = pm.draw(pm.DiscreteUniform.dist(lower=lower, upper=upper), num_groups, random_seed=rng)
    # Random group means. Pass the generator here too — otherwise only the
    # group sizes are reproducible while the simulated values change per run.
    individual_means = pm.draw(pm.Normal.dist(mu=2, sigma=3), len(n_obs), random_seed=rng)
    return pd.DataFrame({
        "data": np.concatenate([
            pm.draw(pm.Normal.dist(mu=mu, sigma=3), n, random_seed=rng)
            for mu, n in zip(individual_means, n_obs)
        ]),
        "idx": np.concatenate([[ii] * n for ii, n in enumerate(n_obs)]),
    })
def make_model(frame: pd.DataFrame) -> pm.Model:
    """Build the hierarchical inference model for the observed data.

    Expects *frame* to have an integer group column "idx" and an
    observation column "data".
    """
    coords = {"idx": frame["idx"].unique()}
    with pm.Model(coords=coords) as model:
        # Group membership of each observation, and the observations themselves
        group_idx = pm.Data("g", frame["idx"].to_numpy())
        obs = pm.Data("y", frame["data"].to_numpy())
        # Population-level parameters
        mu = pm.Normal("mu", 0., 5.)        # population mean
        tau = pm.HalfNormal("tau", 3.)      # between-group spread of the mean
        sigma = pm.HalfNormal("sigma", 3.)  # population standard deviation
        # Latent group-level parameters
        mu_g = pm.Normal("mu_g", mu=mu, sigma=tau, dims="idx")
        sigma_g = pm.HalfNormal("sigma_g", sigma=sigma, dims="idx")
        # Likelihood of the observations
        pm.Normal("y_obs", mu=mu_g[group_idx], sigma=sigma_g[group_idx], observed=obs)
    return model
def make_model_predict(frame: pd.DataFrame) -> pm.Model:
    """Build the prediction model for previously unseen groups.

    The population parameters keep the same names as in the inference
    model so ``sample_posterior_predictive`` can reuse their posterior
    draws, while the group-level variables are renamed ("mu_g_new",
    "sigma_g_new") so they are resampled from their priors instead.
    """
    coords = {"idx": frame["idx"].unique()}
    with pm.Model(coords=coords) as model:
        # Group membership of each observation to predict
        group_idx = pm.Data("g", frame["idx"].to_numpy())
        # Population-level parameters (names must match the inference model)
        mu = pm.Normal("mu", 0., 5.)
        tau = pm.HalfNormal("tau", 3.)
        sigma = pm.HalfNormal("sigma", 3.)
        # Fresh group-level parameters, drawn from their (posterior-informed) priors
        mu_g_new = pm.Normal("mu_g_new", mu=mu, sigma=tau, dims="idx")
        sigma_g_new = pm.HalfNormal("sigma_g_new", sigma=sigma, dims="idx")
        # Predictive distribution — no observed data attached
        pm.Normal("y_obs", mu=mu_g_new[group_idx], sigma=sigma_g_new[group_idx])
    return model
if __name__ == "__main__":
    # Fit the hierarchical model on simulated data.
    frame = generate_data()
    with make_model(frame):
        trace = pm.sample()
    shown_cols = ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]
    summary = az.summary(trace, var_names=["mu", "tau", "sigma"], round_to=4)
    print(summary.loc[:, shown_cols])

    # Predict for a fresh set of unseen groups using only the
    # population-level posterior carried in `trace`.
    frame = generate_data(num_groups=10)
    with make_model_predict(frame[["idx"]]):
        ppc = pm.sample_posterior_predictive(trace, var_names=["y_obs"])
    y_pred = pd.Series(
        data=ppc.posterior_predictive["y_obs"].mean(("chain", "draw")).to_numpy(),
        index=frame["idx"],
    )
    print(y_pred)
I also used a solution for out of sample predictions with a hierarchical model that is essentially suggested here.
The key point is that when you predict on previously unseen groups you can utilise the posterior of the population parameters (`mu`, `tau`, `sigma`), but as there is no information on the specific latent parameters of each group (`mu_g` and `sigma_g`) they will simply be sampled from their priors, i.e. `pm.Normal("mu_g_new", mu=mu, sigma=tau)` and `pm.HalfNormal("sigma_g_new", sigma=sigma)`. The nice thing about those priors is that they are now informed by the posterior of the population parameters.
Also note: we are using the `trace` of the "inference model" to "feed" a posterior sample of the population parameters. There are also samples of the group-specific parameters (`mu_g`, `sigma_g`) in that trace object, so we need to make sure that our "prediction model" does not contain any variables with the same names. If it did, `sample_posterior_predictive` would try to reuse those samples for the unseen groups, which would be conceptually wrong — we do not have any observations to inform `mu_g` and `sigma_g` for new groups. Technically it would probably result in a shape-related error unless the number of groups (and the group labels) stayed identical.
Finally, `ppc.posterior_predictive["y_obs"].mean(("chain", "draw"))` averages the posterior predictive draws of "y_obs" over all chains and draws, i.e. one mean per observation in the input — at least that is how I understand the syntax, but I am not 100% certain; maybe @ricardoV94 or @jessegrabowski could chime in here to clarify.
Disclaimer: I am still learning about hierarchical modelling and how to best use pymc
to work with hierarchical models, so any comments and corrections from more seasoned “veterans” are very welcome.