Unexpected prior predictive behaviour

Hi everyone!

I’m using the following regression model:

with pm.Model() as p1Categorical:
    # nCat: number of categories; p1idx: integer category code for each row of p1
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=2, shape=nCat)
    sigma = pm.HalfCauchy('sigma', 1)

    # expected value per observation: common intercept plus a per-category offset
    mu = pm.Deterministic('mu', alpha + beta[p1idx])
    obs = pm.Normal('obs', mu=mu, sd=sigma, observed=p1['obs'].values)

    trace_p1Categorical = pm.sample()

I intend to make the model hierarchical later, but I'm starting off by modelling only one group.

Now I want to visualise the data, the prior predictive, and the posterior predictive together to see what I'm doing.

priorChecks = pm.sample_prior_predictive(samples=50000, model=p1Categorical)
ppc = pm.sample_posterior_predictive(trace_p1Categorical, 500, var_names=['alpha', 'beta', 'sigma'], model=p1Categorical)

import numpy as np
import pandas as pd

def p1Categoricalm(samples, kind, categories):
    x_cat = categories.categories
    x_codes = np.unique(categories.codes)
    y = np.zeros(0)
    x = np.zeros(0)
    for a, b in zip(samples['alpha'], samples['beta']):
        y_temp = a + b[x_codes]  # one expected value per category for this draw
        x = np.concatenate((x, x_cat))
        y = np.concatenate((y, y_temp))
    # name the column 'x' to match the observed data frame, so seaborn can plot them together
    return pd.DataFrame({'x': x, 'obs': y, 'kind': kind})

prior_predictive = p1Categoricalm(priorChecks, 'prior pc', p1Cat)
posterior_predictive = p1Categoricalm(ppc, 'posterior pc', p1Cat)
real = p1.copy()[['obs', 'x']]
real['kind'] = 'data'
df = pd.concat((real, prior_predictive, posterior_predictive))
import seaborn as sns
import matplotlib.pyplot as plt

fig, axes = plt.subplots(figsize=(12, 5))
sns.lineplot(x="x", y="obs", hue="kind", data=df, ax=axes)

The data and the posterior predictive look as expected. But the prior predictive CI in the plot becomes narrower as the number of prior predictive samples increases, and it is not centred around 0, which it should be according to the prior definition. I've gone up to 50k prior predictive samples and run the whole thing several times, and I keep getting pretty much this:

What am I doing wrong or what did I misunderstand?

It looks like you are cranking out a bunch of samples, plugging them back into your model to get a bunch of predictions (one for each category), taking the mean of those predictions, and then plotting the means. The means might be of interest (e.g., to make sure things aren't totally wacky), but plotting only them essentially ignores how the uncertainty in your prior/posterior propagates through your model and produces uncertainty in the predictions themselves (which is usually the focus of predictive sampling).

To visualize this uncertainty, I would plot a small number of predictions (e.g., 100), separately for each category. After that, I would calculate the means and SDs of all the predictions for each category. Or plot a histogram of all the predictions, separately for each category. Or all of the above!
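For example, here is a rough sketch of the first and last of those ideas. I'm assuming priorChecks and p1Cat are the objects you defined above, and that priorChecks['alpha'] has shape (n_samples,) and priorChecks['beta'] has shape (n_samples, nCat), as sample_prior_predictive normally returns them:

import numpy as np
import matplotlib.pyplot as plt

x_labels = list(p1Cat.categories)   # category labels
x_codes = np.unique(p1Cat.codes)    # integer codes used to index beta

fig, (ax_draws, ax_hist) = plt.subplots(1, 2, figsize=(12, 5))

# spaghetti plot: one faint line per prior draw, across the categories
n_draws = 100
for a, b in zip(priorChecks['alpha'][:n_draws], priorChecks['beta'][:n_draws]):
    ax_draws.plot(x_labels, a + b[x_codes], color='C0', alpha=0.1)
ax_draws.set_title('100 prior draws of alpha + beta, per category')

# histograms of *all* prior draws, one per category
for code, label in zip(x_codes, x_labels):
    ax_hist.hist(priorChecks['alpha'] + priorChecks['beta'][:, code],
                 bins=100, histtype='step', density=True, label=str(label))
ax_hist.legend()
ax_hist.set_title('All prior draws of alpha + beta, by category')

The spaghetti plot shows the spread of the prior directly, so it won't shrink as you draw more samples the way the CI band does.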

Right now, seaborn is plotting the mean of the predictions (separately for each category). The fact that these means are close to 0.0 but not exactly at zero is probably not a big deal (e.g., the discrepancy is much smaller than the SDs of your priors on alpha and beta), but you won't know for sure until you figure out how much uncertainty (variability) there is in each category's predictions.
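For instance, you can get those per-category numbers directly from the prior draws (again assuming the array shapes above):

mu_draws = priorChecks['alpha'][:, None] + priorChecks['beta']  # shape (n_samples, nCat)
print(mu_draws.mean(axis=0))  # per-category means: roughly the flat line seaborn is drawing
print(mu_draws.std(axis=0))   # per-category SDs: about sqrt(10**2 + 2**2) ≈ 10.2 under your priors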

Also note that seaborn is constructing error bars/bands that reflect confidence intervals of the mean. CIs are expected to shrink with increasing sample size and do not reflect the variability you are (probably) most interested in.
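If you want to keep the seaborn figure, you can also ask lineplot for a band showing the standard deviation of the values at each x instead of a confidence interval of the mean; the keyword depends on your seaborn version:

import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 5))
# seaborn >= 0.12: errorbar="sd" shades +/- 1 SD of the values at each x
sns.lineplot(x="x", y="obs", hue="kind", data=df, errorbar="sd", ax=ax)
# on older seaborn versions, pass ci="sd" instead of errorbar="sd"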
