How to implement Hierarchical Gaussian Process in PyMC?

I think I’m getting closer. What I ended up doing is porting the code from GPy, but I have absolutely no confidence that it is correct. Still, the result did look nice.

A bit of context: the input X must include grouping columns (indicated by extra_dims), so the inputs for all groups need to be concatenated together to form a single X. I have only tested 2 levels so far, but 3 levels should be fine (I think).

Also, sampling performance might be hit-or-miss. When I tried increasing n_grps to 4, the NUTS sampling time went from ~3 minutes to over 2-3 hours. I’m not sure what is happening there.

If anyone is familiar with HGPs, please check my work.

Setup Code

%config InlineBackend.figure_format = 'retina'
import GPy
import pymc as pm
import arviz as az
import numpy as np
import pytensor.tensor as pt
import matplotlib.pyplot as plt
from pymc.gp.util import plot_gp_dist

az.style.use("arviz-darkgrid")

np.random.seed(1337)
n_grps = 3
colors = ["lightcoral", "firebrick", "maroon"] 
T = np.linspace(0, 20, 300).reshape(-1, 1)

# construct kernel (covariance function) objects
kern_upper = GPy.kern.Matern32(input_dim=1, variance=1.0, lengthscale=3.0)
kern_lower = GPy.kern.Matern32(input_dim=1, variance=0.25, lengthscale=5.0)

#compute the covariance matrices
K_upper = kern_upper.K(T)
K_lower = kern_lower.K(T)

g = np.random.multivariate_normal(np.zeros(len(T)), K_upper)
f_true = []  # named f_true so it doesn't clash with the GP prior `f` defined in the model below
for r in range(n_grps):
    f_true.append(np.random.multivariate_normal(g, K_lower))

max_samples = np.clip(np.random.uniform(0, 0.2*len(T), size=n_grps).astype(np.int32), 1, len(T))
selected_indices = [np.sort(np.random.permutation(len(T))[:max_samples[i]]) for i in range(n_grps)]

X_i = [T[selected_indices[i]] for i in range(n_grps)]
y_i = [f_true[i][selected_indices[i]] for i in range(n_grps)]

plt.plot(T, g, color="k")
for i in range(n_grps):
    plt.plot(T, f_true[i], "--", color=colors[i])
    plt.plot(X_i[i], y_i[i], "o", color=colors[i])
plt.show()

X, y = [], []
for i in range(len(X_i)):
    X.append(np.concatenate([X_i[i], i*np.ones_like(X_i[i])], axis=-1))
X = np.concatenate(X)
y = np.concatenate(y_i)
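
Quick check on the combined design matrix: column 0 holds the input locations and column 1 holds the integer group index, with the rows of each group contiguous (the custom covariance below relies on that ordering).

print(X.shape, y.shape)     # (total number of selected points, 2) and (total,)
print(np.unique(X[:, 1]))   # -> [0. 1. 2.] for n_grps = 3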

Custom Covariance

def index_to_slices(index):
    # Ported from https://github.com/SheffieldML/GPy/blob/devel/GPy/util/multioutput.py
    # Turns a column of group indices into per-group lists of contiguous row slices.
    if len(index) == 0:
        return []

    # construct the return structure
    ind = np.asarray(index, dtype=int)  # np.int was removed in NumPy >= 1.24
    ret = [[] for i in range(ind.max() + 1)]

    # find the switchpoints
    ind_ = np.hstack((ind, ind[0] + ind[-1] + 1))
    switchpoints = np.nonzero(ind_ - np.roll(ind_, +1))[0]

    for ind_i, indexes_i in zip(ind[switchpoints[:-1]], zip(switchpoints, switchpoints[1:])):
        ret[ind_i].append(slice(*indexes_i))
    return ret
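
For reference, index_to_slices turns a grouping column into one list of contiguous row slices per group index:

slices = index_to_slices(np.array([0, 0, 1, 1, 1, 0]))
# slices[0] == [slice(0, 2), slice(5, 6)]   rows belonging to group 0
# slices[1] == [slice(2, 5)]                rows belonging to group 1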

class Hierarchical(pm.gp.cov.Covariance):
    # Not sure if this should be `pm.gp.cov.Combination`
    def __init__(self, kernels):
        assert len(kernels) >= 2
        assert all([k.input_dim == kernels[0].input_dim for k in kernels])
        assert all([all(k.active_dims == kernels[0].active_dims) for k in kernels])
        input_max = max([k.input_dim for k in kernels])
        
        super().__init__(
            input_dim=kernels[0].input_dim,
            active_dims=kernels[0].active_dims
        )
        self.kernels = kernels
        self.extra_dims = range(input_max, input_max + len(kernels)-1)
    
    def call_kern(self, kernel, X, Xs):
        if isinstance(kernel, pm.gp.cov.Combination):
            return kernel(X, Xs)
        return kernel.full(X, Xs)
    
    def full(self, X, Xs=None):
        # Base kernel
        kernel = self.call_kern(self.kernels[0], X, Xs)
        
        X_slices = [
            index_to_slices(X[:, i]) for i in self.extra_dims
        ]
        
        if Xs is None:
            for k, slices_k in zip(self.kernels[1:], X_slices):
                for slices_i in slices_k:
                    for s in slices_i:
                        kernel = pt.set_subtensor(kernel[s, s], kernel[s, s] + self.call_kern(k, X[s], None))
        else:
            Xs_slices = [
                index_to_slices(Xs[:,i]) for i in self.extra_dims
            ]
            for k, slices_k1, slices_k2 in zip(self.kernels[1:], X_slices, Xs_slices):
                for slices_i, slices_j in zip(slices_k1, slices_k2):
                    for s, ss in zip(slices_i, slices_j):
                        kernel = pt.set_subtensor(kernel[s, ss], kernel[s, ss] + self.call_kern(k, X[s], Xs[ss]))
        return kernel
    
    def diag(self, X):
        return pt.diag(self.full(X, None))
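
As a rough sanity check on the slicing logic (not a proof of correctness), the custom kernel can be evaluated on the simulated X with fixed, arbitrary hyperparameters and compared against the same structure built by hand with NumPy: a shared upper kernel over all rows plus a lower kernel added block-wise within each group. The kernel choices and names here (upper_chk, lower_chk) are only for this check.

upper_chk = pm.gp.cov.Matern32(1, ls=3.0, active_dims=[0])
lower_chk = pm.gp.cov.Matern32(1, ls=5.0, active_dims=[0])

# Custom hierarchical kernel evaluated on the combined design matrix
K_custom = Hierarchical([upper_chk, lower_chk])(X).eval()

# Hand-built version: shared upper kernel plus a per-group lower block
K_manual = upper_chk(X).eval().copy()
groups = X[:, 1].astype(int)
for grp in np.unique(groups):
    idx = np.where(groups == grp)[0]
    K_manual[np.ix_(idx, idx)] += lower_chk(X[idx]).eval()

print(np.allclose(K_custom, K_manual))  # should print True if the block structure matches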

Modeling

with pm.Model() as model:
    # Upper Covariance
    upper_ell = pm.Gamma("upper_ell", alpha=2, beta=1)
    upper_eta = pm.HalfNormal("upper_eta", sigma=5)
    upper_cov = upper_eta**2 * pm.gp.cov.ExpQuad(1, upper_ell, active_dims=[0])
    
    # Lower Covariance
    lower_ell = pm.Gamma("lower_ell", alpha=2, beta=1)
    lower_eta = pm.HalfNormal("lower_eta", sigma=5)
    lower_cov = lower_eta**2 * pm.gp.cov.ExpQuad(1, lower_ell, active_dims=[0])
    
    # Combined Covariance
    h_cov = Hierarchical([upper_cov, lower_cov])
    gp = pm.gp.Latent(cov_func=h_cov)
    
    f = gp.prior("f", X=X)
    sigma = pm.HalfNormal("sigma", sigma=2.0)
    nu = 1 + pm.Gamma(
        "nu", alpha=2, beta=0.1
    )  # add one so the degrees of freedom stay above one (the Student-t mean is undefined for nu <= 1)
    y_ = pm.StudentT("y", mu=f, lam=1.0 / sigma, nu=nu, observed=y)

    idata = pm.sample(100, tune=100, chains=2, cores=6)

    # Predictions
    f_preds = [gp.conditional(f"f_preds_{i}", np.concatenate([T, i * np.ones_like(T)], axis=1), jitter=1e-4) for i in range(n_grps)]
    ppc = pm.sample_posterior_predictive(
        idata.posterior, var_names=[f"f_preds_{i}" for i in range(n_grps)]
    )
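
Regarding the hit-or-miss sampling speed mentioned above, a reasonable first step is to look at standard convergence and efficiency diagnostics for the hyperparameters; low ESS or high r_hat there may indicate weakly identified upper/lower hyperparameters rather than a bug in the covariance code.

print(az.summary(idata, var_names=["upper_ell", "upper_eta", "lower_ell", "lower_eta", "sigma", "nu"]))
az.plot_trace(idata, var_names=["upper_ell", "upper_eta", "lower_ell", "lower_eta"])
plt.show()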

Plotting

f_posts = az.extract(ppc.posterior_predictive, var_names=[f"f_preds_{i}" for i in range(n_grps)]).transpose("sample", ...)

fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

color_palette = ["Reds", "Greens", "Blues"]
for i in range(n_grps):
    plot_gp_dist(ax, f_posts[f"f_preds_{i}"], T, palette=color_palette[i], samples_alpha=0.5)

# axis labels and title
plt.xlabel("X")
plt.ylabel("True f(x)")
plt.title("Posterior distribution over $f(x)$ at the observed values")
plt.show()

fig, axes = plt.subplots(n_grps + 1, 1, figsize=(10, 4*(n_grps + 1)), sharex=True)
color_palette = ["Reds", "Greens", "Blues"]
colors = ["red", "green", "blue"]
axes[0].plot(T, g, color="k", label="True Upper")
for i in range(n_grps):
    axes[0].plot(T, f_true[i], "-", color=colors[i], label=f"True Lower[{i}]")
    axes[0].plot(
        T,
        f_posts[f"f_preds_{i}"].mean("sample").to_numpy(),
        "--", color=colors[i], alpha=0.5
    )
    axes[0].legend(loc="best")
    plot_gp_dist(axes[i + 1], f_posts[f"f_preds_{i}"], T, palette=color_palette[i])
    axes[i + 1].plot(T, f_true[i], color="dodgerblue", label="True Lower")
    axes[i + 1].plot(X_i[i], y_i[i], "ok", label="Samples")
    axes[i + 1].legend(loc="best")

plt.suptitle("Diagnosis Plot")
plt.show()