Improving design of conditional / predictive sampling from GPs

Right now, the predictive distribution of GPs is sampled using sample_ppc, since the conditional distribution is known in closed form given the samples from the GP. This is a bit awkward, and isn't done for other PyMC3 models: the .conditional distribution has to be added to the model after the inference step, e.g.

import pymc3 as pm

# X, y are the training inputs and observations; Xnew holds the prediction inputs
with pm.Model() as marginal_gp_model:
    cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)
    gp = pm.gp.Marginal(cov_func=cov_func)
    sigma = pm.HalfNormal("sigma", sd=5)
    y_ = gp.marginal_likelihood("y", X=X, y=y, noise=sigma)

# inference step here
with marginal_gp_model:
    trace = pm.sample(1000)

# then add the conditional to the model
with marginal_gp_model:
    f_star = gp.conditional("f_star", Xnew=Xnew)

# then sample without NUTS using sample_ppc
with marginal_gp_model:
    pred_samples = pm.sample_ppc(trace, vars=[f_star])

I’ve been trying to think of more PyMC3-ic ways to do this, without sampling the conditional using NUTS, which greatly slows down inference. Here are some thoughts:

  • Use an empty shared variable for Xnew, so that the covariance matrix of the conditional is empty during inference. Of course, this currently raises an error. It would work much like prediction with shared variables already works for regression models (see the sketch after this list).
  • A Gibbs step method, so that .conditional is sampled directly from its known distribution, or a step method that simply skips the .conditional RV when it is assigned to it (a rough sketch follows further below).
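For reference, this is the shared-variable prediction pattern from regression models that the first bullet has in mind. A minimal sketch; x, y, and x_new are assumed data arrays:

import numpy as np
import pymc3 as pm
import theano

x_shared = theano.shared(x)
y_shared = theano.shared(y)

with pm.Model() as lin_model:
    beta = pm.Normal("beta", mu=0, sd=10)
    sigma = pm.HalfNormal("sigma", sd=5)
    mu = beta * x_shared
    pm.Normal("obs", mu=mu, sd=sigma, observed=y_shared)
    trace = pm.sample(1000)

# swap in the prediction inputs; y only needs the right shape here
x_shared.set_value(x_new)
y_shared.set_value(np.zeros(x_new.shape))

with lin_model:
    post_pred = pm.sample_ppc(trace)

The idea would be for gp.conditional to behave analogously: Xnew starts out empty, so NUTS never touches f_star, and the real prediction points are only filled in before sample_ppc.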

Sampling from the GP conditional can be really slow, so it's good to be able to skip it completely while NUTS is run. Any ideas or preferences? Or is it fine how it is?
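For the step-method idea, a rough sketch of what drawing the conditional directly could look like. This is hypothetical, not an existing API: ConditionalSampler and draw_fn are made-up names; only BlockedStep (PyMC3's base class for custom step methods) is real:

from pymc3.step_methods.arraystep import BlockedStep

class ConditionalSampler(BlockedStep):
    # Hypothetical Gibbs-style step: draws the .conditional RV directly
    # from its known closed-form distribution instead of letting NUTS
    # explore it.
    def __init__(self, var, draw_fn):
        self.vars = [var]       # the .conditional RV this step owns
        self.draw_fn = draw_fn  # callable: point -> a draw of var

    def step(self, point):
        point = point.copy()
        point[self.vars[0].name] = self.draw_fn(point)
        return point

Passing an instance via pm.sample(step=...) should leave NUTS auto-assigned to the remaining free RVs, though as noted above, having the prediction RV in the model at all during inference is part of the problem.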

Agree that the conditional should be separate from inference. I am always a bit wary of putting prediction nodes into a PyMC3 model; even when they are completely conditional, their logp is still sometimes added to the model logp.

I think an empty shared variable might be the most PyMC3-ic. The Gibbs step could potentially work as well, however, we might need a specific step method for sampling from the conditional. Overall tho, I actually like adding additional node to do sample_ppc after inference (I saw you doing it and use it myself almost immediately in couple of my problems), and it works pretty well (also it reduce the trace).