Calculating conditional posterior predictive samples in high-dimensional observation spaces

OK I had lunch and the answer to “can PyMC do this for you” is an emphatic “YES”. Special thanks to this blog post by @ricardoV94 and @tcapretto, which shows how you can use pm.sample_posterior_predictive to extend and recycle parts of your model in new and interesting ways.

Here’s the much improved code, then I’ll go over it:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# I wish we could just use np.nan, but pm.MutableData doesn't support it,
# so pick a number here that will never be observed in reality
MISSING_VALUE = -1

with pm.Model(coords=coords) as marginal_model:
    obs_data = pm.MutableData('obs_data', np.array([MISSING_VALUE, 195, 166]), dims=['detector'])
    missing_idx = pt.nonzero(pt.eq(obs_data, MISSING_VALUE), return_matrix=False)[0]
    obs_idx = pt.nonzero(pt.neq(obs_data, MISSING_VALUE), return_matrix=False)[0]
    sorted_idx  = pt.concatenate([missing_idx, obs_idx])
    reverse_idx = pt.argsort(sorted_idx)
    
    n_missing = missing_idx.shape[0]
    n_obs = obs_idx.shape[0]
    
    data_sorted = obs_data[sorted_idx]
    
    # Declare model variables to be the same as the old model
    sd_dist = pm.Exponential.dist(1/0.1)
    chol, corr, sigmas = pm.LKJCholeskyCov('chol', eta=1, n=ndetectors, sd_dist=sd_dist, compute_corr=True)
    cov = pm.Deterministic('cov', chol @ chol.T, dims=['detector', 'detector_aux'])
    mu = pm.Exponential('mu',lam=1/obs_mean, dims=['detector'])
    
    # Do the marginalization math
    mu_sorted = mu[sorted_idx]
    cov_sorted = cov[sorted_idx, :][:, sorted_idx]
    
    mu_u = mu_sorted[:n_missing]
    mu_o = mu_sorted[n_missing:]
    
    cov_uu = cov_sorted[:n_missing, :][:, :n_missing]
    cov_uo = cov_sorted[:n_missing, :][:, n_missing:]
    cov_oo = cov_sorted[n_missing:, :][:, n_missing:]
    
    cov_oo_inv = pt.linalg.solve(cov_oo, pt.eye(n_obs))
    beta = cov_uo @ cov_oo_inv
    
    mu_missing_hat = mu_u + beta @ (data_sorted[n_missing:] - mu_o)
    Sigma_missing_hat = cov_uu - beta @ cov_oo @ beta.T
    
    pm.Deterministic('mu_missing_hat', mu_missing_hat)
    pm.Deterministic('Sigma_missing_hat', Sigma_missing_hat)
    
    # Make a random variable 
    marginal_missing = pm.MvNormal('marginal_missing', mu=mu_missing_hat, cov=Sigma_missing_hat)
    
    # Sample the new random variable, but note that we ***pass the OLD idata***!!!
    idata_marginal = pm.sample_posterior_predictive(idata, var_names=['marginal_missing'])
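
For reference, the “Do the marginalization math” block is just the textbook conditional distribution of a partitioned multivariate normal. Writing u for the missing entries and o for the observed ones,

$$
\mu_{u \mid o} = \mu_u + \Sigma_{uo}\,\Sigma_{oo}^{-1}\,(x_o - \mu_o), \qquad
\Sigma_{u \mid o} = \Sigma_{uu} - \Sigma_{uo}\,\Sigma_{oo}^{-1}\,\Sigma_{ou}
$$

where $\beta = \Sigma_{uo}\Sigma_{oo}^{-1}$ is the beta in the code (and $\beta \Sigma_{oo} \beta^\top = \Sigma_{uo}\Sigma_{oo}^{-1}\Sigma_{ou}$, so the covariance line is the same formula in disguise).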

Basically, what we want to do is take the computational graph we had before (the one that estimated \mu and \Sigma in the first place) and extend it to include some new math. We are absolutely allowed to do this in PyMC, provided we follow a couple rules:

  1. Give the new model a new name: with pm.Model(coords=coords) as marginal_model
  2. Recycle the variable names of anything we want to save. Note that I kept the names of the random variables cov, mu, etc. This is because when we forward sample the model, we want PyMC to use draws from the original model for these values. More on this later.
  3. Don’t call pm.sample; instead, call pm.sample_posterior_predictive, providing the OLD idata!

(3) is the key. Everything we add is just a deterministic computation of the random variables that are already living in idata. So when PyMC runs pm.sample_posterior_predictive, it looks up the variable names living in idata and replaces the matching random variables in our graph with those posterior samples. Since we followed rule (2) and recycled the names, it finds and replaces exactly what we want, and all the quantities downstream of those nodes are computed from the posterior draws, no need to juggle named dimensions. By the way, following rule (1) is what lets us recycle names in the first place (otherwise PyMC will throw an error that the variable was already declared).
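
If the name-matching mechanic feels magical, here is a minimal toy sketch of the same trick, completely separate from the detector example (all names below are made up for illustration):

import numpy as np
import pymc as pm

# Toy data and a toy "original" model
y_obs = np.random.normal(1.0, 0.5, size=50)

with pm.Model() as base_model:
    loc = pm.Normal('loc', 0, 10)
    scale = pm.HalfNormal('scale', 1)
    pm.Normal('y', mu=loc, sigma=scale, observed=y_obs)
    idata_toy = pm.sample()

# New model: recycle the names 'loc' and 'scale', then add new downstream math
with pm.Model() as extended_model:
    loc = pm.Normal('loc', 0, 10)
    scale = pm.HalfNormal('scale', 1)
    pm.Deterministic('snr', loc / scale)

    # Pass the OLD idata: 'loc' and 'scale' are filled in with their posterior
    # draws, and 'snr' is computed from them
    idata_snr = pm.sample_posterior_predictive(idata_toy, var_names=['snr'])

Same three rules at work: a new model name, recycled variable names, and the old idata passed to pm.sample_posterior_predictive.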

Why is this code so much better? Well, I think you’ll agree it’s much easier to read. But we also don’t have to do any manual numpy sampling; it’s all automatic. Here’s the same plot again, as a one-liner:

idata_marginal.posterior_predictive.marginal_missing.plot.hist(bins=100);

[image: histogram of the marginal_missing posterior predictive samples]

But what’s even better is that we can now get conditional predictions for any combination of observed/missing we want, without making a new model, by using pm.set_data. Here’s an example:

with marginal_model:
    pm.set_data({'obs_data':[MISSING_VALUE, MISSING_VALUE, 500]})

    # Still (always!) passing the original idata
    idata_marginal_2 = pm.sample_posterior_predictive(idata, var_names=['marginal_missing'])

missing_idx = idata_marginal_2.posterior_predictive.marginal_missing_dim_2.values
fig, ax = plt.subplots(1, missing_idx.shape[0], figsize=(14,4), dpi=144)
for d in missing_idx:
    idata_marginal_2.posterior_predictive.marginal_missing.sel(marginal_missing_dim_2=d).plot.hist(bins=100, ax=fig.axes[d]);

So very quickly we can get the conditional predicted values for sensors 1 and 2, given that sensor 3 observes 500. Fast and easy.
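
And if you’d rather have numbers than histograms, the usual ArviZ machinery works on the posterior_predictive group (a quick sketch):

import arviz as az

# Mean, sd, and HDI of the conditional predictions for the missing detectors
az.summary(idata_marginal_2, group='posterior_predictive', var_names=['marginal_missing'])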
