How To Use LOOCV With pm.gp.MarginalApprox?

Hi, can someone explain how to get the LOOCV score from a Sparse Marginal GP? I can get it successfully with the Dense Marginal GP:

with pm.Model() as marginal_model:
    
    # Priors. 
    ls = pm.Gamma("ls", alpha=LS_PRIOR_ALPHA, beta=LS_PRIOR_BETA)
    eta = pm.HalfCauchy("eta", beta=ETA_PRIOR_BETA)

    # Zero mean.
    mean = pm.gp.mean.Zero()
    # Matern 5/2.
    cov = eta**2 * pm.gp.cov.Matern52(2, ls)

    # Choose marginal GP.
    # This integrates out the latent function, constraining us to a normally-distributed response.
    gp = pm.gp.Marginal(mean_func=mean, cov_func=cov)
    
    sigma = pm.HalfNormal("sigma", SIGMA_PRIOR)
    marginal = gp.marginal_likelihood("y", X=X_train_scaled_downsampled, y=y_train_downsampled, sigma=sigma)

    marginal_post = pm.sample(draws=1000, tune=500, chains=4, nuts_sampler="numpyro", return_inferencedata=True,)

    marginal_post = pm.compute_log_likelihood(marginal_post)  # adds a log_likelihood group to the InferenceData
    loo_marginal = pm.loo(marginal_post)
    print(f"LOO score (ELPD): {loo_marginal.elpd_loo}")
    print(f"Effective number of parameters (p_loo): {loo_marginal.p_loo}")

But when I run the same code with the sparse MarginalApprox GP, I get “TypeError: log likelihood not found in inference data object.” :open_mouth:

with pm.Model() as sparse_model:

    # Same priors as the gp.Marginal model, except this one uses the entire dataset.
    ls = pm.Gamma("ls", alpha=LS_PRIOR_ALPHA, beta=LS_PRIOR_BETA)
    eta = pm.HalfCauchy("eta", beta=ETA_PRIOR_BETA)

    cov = eta**2 * pm.gp.cov.Matern52(2, ls)
    
    gp = pm.gp.MarginalApprox(cov_func=cov, approx="FITC")

    NUM_INDUCING_POINTS = 150
    Xu = pm.gp.util.kmeans_inducing_points(NUM_INDUCING_POINTS, X_train_scaled)
    
    sigma = pm.HalfNormal("sigma", SIGMA_PRIOR)
    y_ = gp.marginal_likelihood("y", 
                                X=X_train_scaled, 
                                Xu=Xu, 
                                y=y_train, 
                                sigma=sigma,
                                jitter=1e-5) # Keep the matrix happy.
    
    sparse_post = pm.sample(draws=1000, tune=500, chains=4, nuts_sampler="numpyro", return_inferencedata=True,)

    """
    This actually doesn't work. PyMC does not seem to support LOOCV with the sparse model.
    
    sparse_post = pm.compute_log_likelihood(sparse_post)
    loo_sparse = pm.loo(sparse_post)
    print(f"LOO score (ELPD): {loo_sparse.elpd_loo}")
    print(f"Effective number of parameters (p_loo): {loo_sparse.p_loo}")

    Gives error at this line: loo_sparse = pm.loo(sparse_post)
    TypeError: log likelihood not found in inference data object
    """

That error makes sense: for this model the marginal likelihood is actually a lower bound on the true marginal likelihood (if you’re using the default VFE approximation). Instead of an observed MvNormal, it’s added to the model with pm.Potential, so there is no observed variable whose pointwise log-likelihood pm.compute_log_likelihood can evaluate. That means you’ll have to calculate it from scratch.

I’d have to double-check, but I think the DTC and FITC approximations use a proper likelihood rather than a lower bound. That might help with the calculation.

Ok, thanks. I did use approx="FITC" in my code, so I assume that was an oversight when you said FITC might have a proper likelihood.

Fundamentally, I don’t understand much of this theory so I’m not equipped to implement something like calculating the likelihood from scratch.

Let’s say I just want to compare these two models and decide which is better. I think the long-term answer is: “go learn the theory.” Check…understood. But what is my short-term answer? Are there any easy one-liners I can add to get some kind of model comparison? My only idea for a short-term fix would be to implement a manual k-fold cross-validation (with small k!!!). Obviously that would blow up the run time, because MCMC is slow.
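Something like this skeleton is what I have in mind (assumes scikit-learn; fit_and_score is a hypothetical placeholder that would rebuild the sparse model on the training fold, sample, and return the log predictive density of the held-out fold, e.g. via gp.conditional and pm.sample_posterior_predictive):

```python
import numpy as np
from sklearn.model_selection import KFold  # assumes scikit-learn is installed

def kfold_elpd(X, y, fit_and_score, k=3, seed=0):
    """Sum of held-out scores over k folds; higher is better.

    fit_and_score(X_tr, y_tr, X_te, y_te) should refit the model on the
    training fold and return the log predictive density of the test fold.
    """
    kf = KFold(n_splits=k, shuffle=True, random_state=seed)
    return float(sum(
        fit_and_score(X[tr], y[tr], X[te], y[te])
        for tr, te in kf.split(X)
    ))
```

With k=3 this only triples the sampling time, so at least the cost is bounded.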

It might not be as bad as you expect. gp.MarginalApprox uses the implementation described in this paper, particularly equation 5. pm.Potential is used to implement the likelihood because of the trace term in the VFE objective. FITC doesn’t have that term, so you could modify the code to use an MvNormal in that case without too much extra work.
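A rough sketch of what that FITC marginal likelihood looks like, in plain numpy/scipy rather than PyMC (the kernel and helper names are my own, and it builds the full N×N covariance, so treat it as intuition rather than an efficient implementation). FITC keeps the exact diagonal of Kff and approximates the off-diagonals with the Nystrom matrix Qff, so when Xu equals X it reduces to the exact dense marginal likelihood:

```python
import numpy as np
from scipy.stats import multivariate_normal

def matern52(X1, X2, eta, ls):
    # Matern 5/2: k(r) = eta^2 (1 + sqrt(5) r + 5 r^2 / 3) exp(-sqrt(5) r), r = |x - x'| / ls
    r = np.sqrt(((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)) / ls
    return eta**2 * (1.0 + np.sqrt(5.0) * r + 5.0 * r**2 / 3.0) * np.exp(-np.sqrt(5.0) * r)

def fitc_logp(y, X, Xu, eta, ls, sigma, jitter=1e-8):
    # Nystrom approximation Qff = Kfu Kuu^{-1} Kuf
    Kuu = matern52(Xu, Xu, eta, ls) + jitter * np.eye(len(Xu))
    Kuf = matern52(Xu, X, eta, ls)
    Qff = Kuf.T @ np.linalg.solve(Kuu, Kuf)
    # FITC correction: restore the exact diagonal of Kff, then add noise
    Kff_diag = np.full(len(X), eta**2)  # k(x, x) = eta^2 for Matern52
    Sigma = Qff + np.diag(Kff_diag - np.diag(Qff) + sigma**2)
    return multivariate_normal(mean=np.zeros(len(X)), cov=Sigma).logpdf(y)
```

Because Sigma is an ordinary N×N covariance with no trace term attached, this is the quantity you could hand to an observed MvNormal inside the model.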

Ok very cool thanks. Looks like all the ingredients are here.