First of all, beware of the centered hierarchical parametrization you are using: it may lead to divergences and difficulties while fitting.
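The usual remedy is the non-centered form, where each coefficient is written as `mu_beta + sigma_beta * b` with `b` standard normal (the model in the example further down does exactly this). Here is a minimal NumPy sketch, with made-up values for `mu_beta` and `sigma_beta`, showing that both forms give the same distribution for the coefficients. What changes is the geometry the sampler has to explore, and that is where the centered version tends to produce divergences.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical group-level parameters (assumed values, for illustration only)
mu_beta, sigma_beta = 0.5, 2.0
n_draws = 200_000

# Centered: draw the coefficients directly from Normal(mu_beta, sigma_beta)
beta_centered = rng.normal(loc=mu_beta, scale=sigma_beta, size=n_draws)

# Non-centered: draw a standard normal offset b, then shift and scale it
b = rng.standard_normal(n_draws)
beta_noncentered = mu_beta + sigma_beta * b

# Both sample means are close to mu_beta and both stds are close to sigma_beta
print(beta_centered.mean(), beta_noncentered.mean())
print(beta_centered.std(), beta_noncentered.std())
```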
That being said, your model looks more or less like a GLM with shared prior random variates `mu_beta` and `sigma_beta` across features and sites. Once you get a posterior distribution over those two, your predictions should look something like this:
```
y_hat = a + dot(X_shared, Normal(mu=mu_beta, sigma=sigma_beta))
y_like = Bernoulli('y_like', logit_p=y_hat)
```
So that is what we will aim for.
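To make the arithmetic concrete, here is a plain NumPy sketch of that predictive for an unseen site. It assumes you already hold posterior draws of `mu_beta`, `sigma_beta` and `a`; below they are faked with random numbers, so the values themselves mean nothing. In the PyMC3 example further down, `sample_posterior_predictive` does this work for you.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for posterior draws of the shared (global) parameters.
# In practice these come from the trace fitted on the training sites;
# here they are made-up numbers purely for illustration.
n_draws, n_features = 1000, 10
mu_beta_draws = rng.normal(0.0, 0.1, size=n_draws)
sigma_beta_draws = np.abs(rng.normal(1.0, 0.1, size=n_draws))
a_draws = rng.normal(0.0, 0.1, size=n_draws)

# Features of the observations at the unseen site (one observation here)
X_new = rng.standard_normal((1, n_features))

# For each posterior draw, sample the unseen site's coefficients from the
# group-level distribution Normal(mu_beta, sigma_beta)
beta_new = rng.normal(
    loc=mu_beta_draws[:, None],
    scale=sigma_beta_draws[:, None],
    size=(n_draws, n_features),
)

# Linear predictor y_hat = a + X_new . beta_new, one row per draw
y_hat = a_draws[:, None] + beta_new @ X_new.T  # shape (n_draws, n_obs)

# Bernoulli draws with a logit link: p = sigmoid(y_hat)
p = 1.0 / (1.0 + np.exp(-y_hat))
y_pred = rng.binomial(1, p)  # shape (n_draws, n_obs), values in {0, 1}

print(y_pred.shape, y_pred.mean())
```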
The approach we always recommend for out-of-sample posterior predictive checks is to use `theano.shared` variables. I'll use a different approach, inspired by the functional API that is the core design idea behind PyMC4. There are many differences between PyMC3 and the PyMC4 skeleton that I won't go into, but one thing I have started to use more is factory functions that return `Model` instances. Instead of trying to swap the data inside the model with `theano.shared` variables, I just create a new model with the new data and draw posterior predictive samples from it. I recently posted about this here.
The idea is to create the model with the training data and sample from it to get a `trace`. Then you have to extract from the trace the hierarchical part that is shared with the unseen site: `mu_beta`, `sigma_beta` and `a`. Finally, you create a new model using the new data of the test site, and sample from the posterior predictive using a list of dictionaries that hold the `mu_beta`, `sigma_beta` and `a` part of the training `trace`. Here's a self-contained example:
```python
import numpy as np
import pymc3 as pm
from theano import tensor as tt
from matplotlib import pyplot as plt


def model_factory(X, y, site_shared, n_site, n_features=None):
    if n_features is None:
        n_features = X.shape[-1]
    with pm.Model() as model:
        # Global (shared) parameters
        mu_beta = pm.Normal('mu_beta', mu=0., sd=1)
        sigma_beta = pm.HalfCauchy('sigma_beta', 5)
        a = pm.Normal('a', mu=0., sd=1)
        # Non-centered site-specific coefficients
        b = pm.Normal('b', mu=0, sd=1, shape=(n_features, n_site))
        betas = mu_beta + sigma_beta * b
        # Each observation uses only the coefficients of its own site
        y_hat = a + tt.sum(X * betas[:, site_shared].T, axis=1)
        pm.Bernoulli('y_like', logit_p=y_hat, observed=y)
    return model


# First I generate some training X data
n_features = 10
ntrain_site = 5
ntrain_obs = 100
ntest_site = 1
ntest_obs = 1
train_X = np.random.randn(ntrain_obs, n_features)
train_site_shared = np.random.randint(ntrain_site, size=ntrain_obs)
new_site_X = np.random.randn(ntest_obs, n_features)
test_site_shared = np.zeros(ntest_obs, dtype=np.int32)

# Now I generate the training and test y data with a sample from the prior
# (the y passed here is only a placeholder for the observed values)
with model_factory(X=train_X,
                   y=np.zeros(ntrain_obs, dtype=np.int32),
                   site_shared=train_site_shared,
                   n_site=ntrain_site) as train_y_generator:
    train_Y = pm.sample_prior_predictive(1, vars=['y_like'])['y_like'][0]
with model_factory(X=new_site_X,
                   y=np.zeros(ntest_obs, dtype=np.int32),
                   site_shared=test_site_shared,
                   n_site=ntest_site) as test_y_generator:
    new_site_Y = pm.sample_prior_predictive(1, vars=['y_like'])['y_like'][0]
# The previous part is just to get some toy data to fit

# Now come the important parts. First, training
with model_factory(X=train_X,
                   y=train_Y,
                   site_shared=train_site_shared,
                   n_site=ntrain_site) as train_model:
    train_trace = pm.sample()

# Second comes the hold-out data posterior predictive
with model_factory(X=new_site_X,
                   y=new_site_Y,
                   site_shared=test_site_shared,
                   n_site=ntrain_site) as test_model:
    # We first have to extract the learnt global effects from the train_trace
    df = pm.trace_to_dataframe(train_trace,
                               varnames=['mu_beta', 'sigma_beta', 'a'],
                               include_transformed=True)
    # We have to supply the samples kwarg because it cannot be inferred if the
    # input trace is not a MultiTrace instance. 'b' is not in the records, so
    # it is drawn from its prior (given mu_beta and sigma_beta) for each sample
    ppc = pm.sample_posterior_predictive(trace=df.to_dict('records'),
                                         samples=len(df))

plt.figure()
plt.hist(ppc['y_like'], 30)
# Observed value of the single test observation
plt.axvline(new_site_Y[0], linestyle='--', color='r')
plt.show()
```
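The `df.to_dict('records')` call turns the selected columns of the trace into one dictionary per posterior draw. A pandas-free sketch of the same structure follows, with a made-up `trace_like` dict standing in for `train_trace`. It shows that only the global variables are carried over to the new site, while the site-specific `b` is dropped.

```python
import numpy as np

# Hypothetical stand-in for a fitted trace: one array of draws per variable
# (made-up values; in the real workflow these come from train_trace)
rng = np.random.default_rng(3)
n_draws = 4
trace_like = {
    'mu_beta': rng.normal(size=n_draws),
    'sigma_beta': np.abs(rng.normal(size=n_draws)),
    'a': rng.normal(size=n_draws),
    'b': rng.normal(size=(n_draws, 10, 5)),  # site-specific, NOT reused
}

shared = ['mu_beta', 'sigma_beta', 'a']

# One dict per draw, keeping only the shared/global variables
records = [{name: float(trace_like[name][i]) for name in shared}
           for i in range(n_draws)]

print(len(records), sorted(records[0]))  # → 4 ['a', 'mu_beta', 'sigma_beta']
```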
The posterior predictive I get looks like this:
Of course, I don't know what data you actually have for `X_shared`, `site_shared` or `train_y`, so I made up some nonsense toy data at the beginning of the code. You should replace it with your actual data.
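One thing to double-check when adapting this to your data is the shape of the linear predictor. `betas[:, site_shared]` holds one column of coefficients per observation, so each observation should be multiplied only by its own column. A plain matrix product `X @ betas[:, site_shared]` produces an `(n_obs, n_obs)` array instead. Here is a NumPy sketch with made-up sizes; the Theano counterpart of the row-wise version is `tt.sum(X * betas[:, site_shared].T, axis=1)`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_features, n_site = 6, 4, 3

X = rng.standard_normal((n_obs, n_features))
betas = rng.standard_normal((n_features, n_site))
site = rng.integers(0, n_site, size=n_obs)  # site index of each observation

cols = betas[:, site]  # shape (n_features, n_obs), one column per observation

# A full matrix product mixes every observation with every column
full = X @ cols  # shape (n_obs, n_obs)

# Row-wise dot product: observation i uses only its own site's coefficients
rowwise = np.sum(X * cols.T, axis=1)  # shape (n_obs,)

# The row-wise result is exactly the diagonal of the full product
print(full.shape, rowwise.shape)  # → (6, 6) (6,)
print(np.allclose(rowwise, np.diag(full)))  # → True
```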