I am trying to compare different GP models using az.loo or az.waic. In PyMC, these methods rely on the pointwise log-likelihood, which is computed by the pm.compute_log_likelihood function.
The problem I am facing is that, with GP models, compute_log_likelihood does not return a value for each observation.
I was expecting the log_likelihood group to have shape (chain, draw, y_dim_0), but instead it has shape (chain, draw).
Is there something I am missing?
Here is a reproducible example adapted from the Marginal Likelihood Implementation notebook:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import scipy as sp
# set the seed
np.random.seed(1)
n = 100 # The number of data points
X = np.linspace(0, 10, n)[:, None] # The inputs to the GP, they must be arranged as a column vector
# Define the true covariance function and its parameters
ell_true = 1.0
eta_true = 3.0
cov_func = eta_true**2 * pm.gp.cov.Matern52(1, ell_true)
# A mean function that is zero everywhere
mean_func = pm.gp.mean.Zero()
# The latent function values are one sample from a multivariate normal
# Note that we have to call `eval()` because PyMC is built on top of PyTensor
f_true = np.random.multivariate_normal(
    mean_func(X).eval(), cov_func(X).eval() + 1e-8 * np.eye(n), 1
).flatten()
# The observed data is the latent function plus a small amount of IID Gaussian noise
# The standard deviation of the noise is `sigma`
sigma_true = 2.0
y = f_true + sigma_true * np.random.randn(n)
with pm.Model() as model:
    ell = pm.Gamma("ell", alpha=2, beta=1)
    eta = pm.HalfCauchy("eta", beta=5)
    cov = eta**2 * pm.gp.cov.Matern52(1, ell)
    gp = pm.gp.Marginal(cov_func=cov)
    sigma = pm.HalfCauchy("sigma", beta=5)
    y_ = gp.marginal_likelihood("y", X=X, y=y, sigma=sigma)
with model:
    marginal_post = pm.sample(nuts_sampler="pymc", idata_kwargs={"log_likelihood": True})
az.loo(marginal_post)
As expected, the LOO computation raises the following UserWarning:
‘The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your model to make sure it returns element-wise logp.’
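To illustrate the shape difference I mean, here is a minimal NumPy-only sketch. It is not PyMC's implementation: the toy kernel, the noise level, and the helper names mvn_logpdf and iid_normal_logpdf are all made up for this example. It only shows that a joint multivariate normal density gives one scalar for the whole vector, while independent normals give one value per observation. My guess is that this is why the log_likelihood group ends up with shape (chain, draw).

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5

# Hypothetical toy covariance: squared-exponential kernel plus noise on the diagonal
x = np.linspace(0, 1, n)
noise_sd = 0.5
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.3**2) + noise_sd**2 * np.eye(n)
y = rng.multivariate_normal(np.zeros(n), K)


def mvn_logpdf(y, cov):
    """Joint log-density of a zero-mean multivariate normal: one scalar for all of y."""
    L = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(L, y)  # whitened residuals
    return (
        -0.5 * alpha @ alpha
        - np.log(np.diag(L)).sum()
        - 0.5 * len(y) * np.log(2 * np.pi)
    )


def iid_normal_logpdf(y, sd):
    """Element-wise log-density of independent zero-mean normals: one value per point."""
    return -0.5 * (y / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)


joint = mvn_logpdf(y, K)                     # scalar: analogous to shape (chain, draw)
pointwise = iid_normal_logpdf(y, noise_sd)   # vector: analogous to (chain, draw, y_dim_0)

print(np.shape(joint))
print(pointwise.shape)
```

The two only agree (after summing the pointwise values) when the covariance is diagonal, so the joint value cannot simply be split into per-observation terms when the observations are correlated.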