OK, I had lunch, and the answer to "can PyMC do this for you" is an emphatic "YES". Special thanks to this blog post by @ricardoV94 and @tcapretto, which shows how you can use `pm.sample_posterior_predictive` to extend and recycle parts of your model in new and interesting ways.
Here’s the much improved code, then I’ll go over it:
```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

# coords, ndetectors, obs_mean, and idata all come from the original model above

# I wish we could just use np.nan, but pm.MutableData doesn't support it,
# so pick a number here that will never be observed in reality
MISSING_VALUE = -1

with pm.Model(coords=coords) as marginal_model:
    obs_data = pm.MutableData('obs_data', np.array([MISSING_VALUE, 195, 166]), dims=['detector'])
    missing_idx = pt.nonzero(pt.eq(obs_data, MISSING_VALUE), return_matrix=False)[0]
    obs_idx = pt.nonzero(pt.neq(obs_data, MISSING_VALUE), return_matrix=False)[0]

    sorted_idx = pt.concatenate([missing_idx, obs_idx])
    reverse_idx = pt.argsort(sorted_idx)
    n_missing = missing_idx.shape[0]
    n_obs = obs_idx.shape[0]
    data_sorted = obs_data[sorted_idx]

    # Declare model variables to be the same as the old model
    sd_dist = pm.Exponential.dist(1 / 0.1)
    chol, corr, sigmas = pm.LKJCholeskyCov('chol', eta=1, n=ndetectors, sd_dist=sd_dist, compute_corr=True)
    cov = pm.Deterministic('cov', chol @ chol.T, dims=['detector', 'detector_aux'])
    mu = pm.Exponential('mu', lam=1 / obs_mean, dims=['detector'])

    # Do the marginalization math
    mu_sorted = mu[sorted_idx]
    cov_sorted = cov[sorted_idx, :][:, sorted_idx]

    mu_u = mu_sorted[:n_missing]
    mu_o = mu_sorted[n_missing:]

    cov_uu = cov_sorted[:n_missing, :][:, :n_missing]
    cov_uo = cov_sorted[:n_missing, :][:, n_missing:]
    cov_oo = cov_sorted[n_missing:, :][:, n_missing:]

    cov_oo_inv = pt.linalg.solve(cov_oo, pt.eye(n_obs))
    beta = cov_uo @ cov_oo_inv

    mu_missing_hat = mu_u + beta @ (data_sorted[n_missing:] - mu_o)
    Sigma_missing_hat = cov_uu - beta @ cov_oo @ beta.T

    pm.Deterministic('mu_missing_hat', mu_missing_hat)
    pm.Deterministic('Sigma_missing_hat', Sigma_missing_hat)

    # Make a random variable
    marginal_missing = pm.MvNormal('marginal_missing', mu=mu_missing_hat, cov=Sigma_missing_hat)

    # Sample the new random variable, but note that we ***pass the OLD idata***!!!
    idata_marginal = pm.sample_posterior_predictive(idata, var_names=['marginal_missing'])
```
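The "marginalization math" block is just the standard conditional-Gaussian formulas: for x \sim N(\mu, \Sigma) partitioned into unobserved (u) and observed (o) parts, \mu_{u|o} = \mu_u + \Sigma_{uo} \Sigma_{oo}^{-1} (x_o - \mu_o) and \Sigma_{u|o} = \Sigma_{uu} - \Sigma_{uo} \Sigma_{oo}^{-1} \Sigma_{ou} (the Schur complement). Here's a quick standalone numpy check of those formulas; the numbers are made up for illustration and have nothing to do with the model above:

```python
import numpy as np

# Toy 3-detector Gaussian: detector 0 unobserved, detectors 1 and 2 observed
mu = np.array([10.0, 20.0, 30.0])
Sigma = np.array([[4.0, 1.0, 0.5],
                  [1.0, 3.0, 0.8],
                  [0.5, 0.8, 2.0]])
missing, obs = np.array([0]), np.array([1, 2])
x_obs = np.array([19.0, 31.0])

# beta = Sigma_uo @ inv(Sigma_oo), same role as `beta` in the model
beta = Sigma[np.ix_(missing, obs)] @ np.linalg.inv(Sigma[np.ix_(obs, obs)])

# Conditional mean and covariance of the unobserved part given the observed part
mu_hat = mu[missing] + beta @ (x_obs - mu[obs])
Sigma_hat = Sigma[np.ix_(missing, missing)] - beta @ Sigma[np.ix_(obs, missing)]

print(mu_hat, Sigma_hat)  # mu_hat ≈ 9.832, Sigma_hat ≈ 3.636
```

Note the conditional variance (≈3.64) is smaller than the marginal variance (4.0), as it must be: conditioning on correlated observations can only reduce uncertainty.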
Basically, what we want to do is take the computational graph we had before (the one that estimated \mu and \Sigma in the first place) and extend it to include some new math. We are absolutely allowed to do this in PyMC, provided we follow a few rules:
1. Give the new model a new name: `with pm.Model(coords=coords) as marginal_model:`
2. Recycle the variable names of anything we want to save. Note that I kept the names of the random variables `cov`, `mu`, etc. This is because when we forward sample the model, we want PyMC to use draws of the original model for these values. More on this later.
3. Don't call `pm.sample`; instead call `pm.sample_posterior_predictive`, providing the OLD `idata`!
(3) is the key. Everything we add is just a deterministic computation of the random variables that are already living in `idata`. So when PyMC runs `pm.sample_posterior_predictive`, it looks for the names of the variables living in `idata` and replaces the random variables in our graph with those samples. Since we followed rule (2) and recycled the names, it will find and replace exactly what we want. All the quantities downstream of those nodes will be computed as we want, with no need to juggle named dimensions. By the way, following rule (1) is what lets us recycle names in the first place (otherwise PyMC will throw an error that the variable was already declared).
Why is this code so much better? Well, I think you'll agree it's much easier to read. But we also don't have to do any numpy sampling; it's all automatic. Here's the same plot again, as a one-liner:
```python
idata_marginal.posterior_predictive.marginal_missing.plot.hist(bins=100);
```
But what's even better is that we can now get conditional predictions for any combination of observed/missing we want, without making a new model, by using `pm.set_data`. Here's an example:
```python
with marginal_model:
    pm.set_data({'obs_data': [MISSING_VALUE, MISSING_VALUE, 500]})

    # Still (always!) passing the original idata
    idata_marginal_2 = pm.sample_posterior_predictive(idata, var_names=['marginal_missing'])

missing_idx = idata_marginal_2.posterior_predictive.marginal_missing_dim_2.values
fig, ax = plt.subplots(1, missing_idx.shape[0], figsize=(14, 4), dpi=144)
for d in missing_idx:
    idata_marginal_2.posterior_predictive.marginal_missing.sel(marginal_missing_dim_2=d).plot.hist(bins=100, ax=fig.axes[d]);
```
So very quickly we can get the conditional predicted values for sensors 1 and 2, given that sensor 3 observes 500. Fast and easy.