I think I’m getting closer. What I ended up doing is porting the code from GPy, but I have absolutely no confidence that it is correct. Still, the result does look nice.
A bit of context: the input X must include grouping columns (indicated by extra_dims), so the per-group inputs all need to be concatenated together to form the X input. Note that I have only tested 2 levels, but 3 levels should be fine (I think).
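To make that layout concrete, here is a hypothetical two-group illustration (the names here are my own, not from the code below):

import numpy as np

# Hypothetical example: column 0 is the actual input, column 1 is the
# group index that the custom covariance slices on via its extra dims.
X_demo = np.array([
    [0.5, 0.0],  # t = 0.5, group 0
    [1.2, 0.0],  # t = 1.2, group 0
    [0.3, 1.0],  # t = 0.3, group 1
    [2.0, 1.0],  # t = 2.0, group 1
])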
Also, the sampling performance might be hit-or-miss. When I tried increasing n_grps to 4, the NUTS sampling time went from ~3 minutes to over 2-3 hours. I’m not sure what is happening here.
If anyone is familiar with HGPs, please check my work.
Setup Code
%config InlineBackend.figure_format = 'retina'
import GPy
import pymc as pm
import arviz as az
import numpy as np
import pytensor.tensor as pt
import matplotlib.pyplot as plt
from pymc.gp.util import plot_gp_dist
az.style.use("arviz-darkgrid")
np.random.seed(1337)
n_grps = 3
colors = ["lightcoral", "firebrick", "maroon"]
T = np.linspace(0, 20, 300).reshape(-1, 1)
# construct kernel (covariance function) objects
kern_upper = GPy.kern.Matern32(input_dim=1, variance=1.0, lengthscale=3.0)
kern_lower = GPy.kern.Matern32(input_dim=1, variance=0.25, lengthscale=5.0)
#compute the covariance matrices
K_upper = kern_upper.K(T)
K_lower = kern_lower.K(T)
g = np.random.multivariate_normal(np.zeros(len(T)), K_upper)
f_true = []
for r in range(n_grps):
    f_true.append(np.random.multivariate_normal(g, K_lower))
max_samples = np.clip(np.random.uniform(0, 0.2*len(T), size=n_grps).astype(np.int32), 1, len(T))
selected_indices = [np.sort(np.random.permutation(len(T))[:max_samples[i]]) for i in range(n_grps)]
X_i = [T[selected_indices[i]] for i in range(n_grps)]
y_i = [f_true[i][selected_indices[i]] for i in range(n_grps)]
plt.plot(T, g, color="k")
for i in range(n_grps):
    plt.plot(T, f_true[i], "--", color=colors[i])
    plt.plot(X_i[i], y_i[i], "o", color=colors[i])
plt.show()
X, y = [], []
for i in range(len(X_i)):
    X.append(np.concatenate([X_i[i], i * np.ones_like(X_i[i])], axis=-1))
X = np.concatenate(X)
y = np.concatenate(y_i)
Custom Covariance
def index_to_slices(index):
    # From https://github.com/SheffieldML/GPy/blob/devel/GPy/util/multioutput.py
    if len(index) == 0:
        return []
    # construct the return structure
    ind = np.asarray(index, dtype=int)  # np.int was removed in recent NumPy
    ret = [[] for i in range(ind.max() + 1)]
    # find the switchpoints
    ind_ = np.hstack((ind, ind[0] + ind[-1] + 1))
    switchpoints = np.nonzero(ind_ - np.roll(ind_, +1))[0]
    # each run of equal indices becomes one slice in that index's list
    for ind_i, indexes_i in zip(ind[switchpoints[:-1]], zip(switchpoints, switchpoints[1:])):
        ret[ind_i].append(slice(*indexes_i))
    return ret
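To see what this helper returns, here is a small check (my own example): each run of equal indices becomes one slice, so non-contiguous groups yield multiple slices.

# index 0 appears in two separate runs, index 1 in one run
print(index_to_slices([0, 0, 1, 1, 1, 0]))
# -> [[slice(0, 2, None), slice(5, 6, None)], [slice(2, 5, None)]]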
class Hierarchical(pm.gp.cov.Covariance):
    # Not sure if this should be `pm.gp.cov.Combination`
    def __init__(self, kernels):
        assert len(kernels) >= 2
        assert all(k.input_dim == kernels[0].input_dim for k in kernels)
        assert all(all(k.active_dims == kernels[0].active_dims) for k in kernels)
        input_max = max(k.input_dim for k in kernels)
        super().__init__(
            input_dim=kernels[0].input_dim,
            active_dims=kernels[0].active_dims,
        )
        self.kernels = kernels
        # one grouping column per level below the top
        self.extra_dims = range(input_max, input_max + len(kernels) - 1)

    def call_kern(self, kernel, X, Xs):
        if isinstance(kernel, pm.gp.cov.Combination):
            return kernel(X, Xs)
        return kernel.full(X, Xs)

    def full(self, X, Xs=None):
        # base (upper-level) kernel covers every pair of points
        kernel = self.call_kern(self.kernels[0], X, Xs)
        X_slices = [index_to_slices(X[:, i]) for i in self.extra_dims]
        if Xs is None:
            # add each lower-level kernel on its within-group diagonal blocks
            for k, slices_k in zip(self.kernels[1:], X_slices):
                for slices_i in slices_k:
                    for s in slices_i:
                        kernel = pt.set_subtensor(
                            kernel[s, s], kernel[s, s] + self.call_kern(k, X[s], None)
                        )
        else:
            Xs_slices = [index_to_slices(Xs[:, i]) for i in self.extra_dims]
            for k, slices_k1, slices_k2 in zip(self.kernels[1:], X_slices, Xs_slices):
                for slices_i, slices_j in zip(slices_k1, slices_k2):
                    for s, ss in zip(slices_i, slices_j):
                        kernel = pt.set_subtensor(
                            kernel[s, ss], kernel[s, ss] + self.call_kern(k, X[s], Xs[ss])
                        )
        return kernel

    def diag(self, X):
        return pt.diag(self.full(X, None))
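As a sanity check on the block structure, here is a minimal sketch (my own test, not part of the model): on a tiny input with two groups, the full matrix should equal the upper kernel everywhere plus the lower kernel added on the within-group diagonal blocks.

# Tiny input: two points in group 0, two points in group 1
t_chk = np.array([[0.0], [1.0], [2.0], [3.0]])
grp_chk = np.array([[0.0], [0.0], [1.0], [1.0]])
X_chk = np.concatenate([t_chk, grp_chk], axis=-1)

upper_chk = pm.gp.cov.ExpQuad(1, ls=3.0, active_dims=[0])
lower_chk = pm.gp.cov.ExpQuad(1, ls=5.0, active_dims=[0])
K_chk = Hierarchical([upper_chk, lower_chk]).full(X_chk).eval()

# Expected: upper kernel everywhere, lower kernel per group block
K_exp = upper_chk(X_chk).eval()
K_exp[:2, :2] += lower_chk(X_chk[:2]).eval()
K_exp[2:, 2:] += lower_chk(X_chk[2:]).eval()
np.testing.assert_allclose(K_chk, K_exp, rtol=1e-6)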
Modeling
with pm.Model() as model:
    # Upper covariance
    upper_ell = pm.Gamma("upper_ell", alpha=2, beta=1)
    upper_eta = pm.HalfNormal("upper_eta", sigma=5)
    upper_cov = upper_eta**2 * pm.gp.cov.ExpQuad(1, upper_ell, active_dims=[0])

    # Lower covariance
    lower_ell = pm.Gamma("lower_ell", alpha=2, beta=1)
    lower_eta = pm.HalfNormal("lower_eta", sigma=5)
    lower_cov = lower_eta**2 * pm.gp.cov.ExpQuad(1, lower_ell, active_dims=[0])

    # Combined covariance
    h_cov = Hierarchical([upper_cov, lower_cov])
    gp = pm.gp.Latent(cov_func=h_cov)
    f = gp.prior("f", X=X)

    sigma = pm.HalfNormal("sigma", sigma=2.0)
    # add one because the Student-t is undefined for degrees of freedom less than one
    nu = 1 + pm.Gamma("nu", alpha=2, beta=0.1)
    y_ = pm.StudentT("y", mu=f, lam=(1.0 / sigma), nu=nu, observed=y)

    idata = pm.sample(100, tune=100, chains=2, cores=6)

    # Predictions on the full grid T for each group
    f_preds = [
        gp.conditional(f"f_preds_{i}", np.concatenate([T, i * np.ones_like(T)], axis=1), jitter=1e-4)
        for i in range(n_grps)
    ]
    ppc = pm.sample_posterior_predictive(
        idata.posterior, var_names=[f"f_preds_{i}" for i in range(n_grps)]
    )
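With only 100 tuning steps and 100 draws I would not fully trust the fit, so it is worth at least glancing at the standard diagnostics, e.g.:

# Quick convergence check: r_hat should be near 1.0 and ESS not tiny
print(az.summary(idata, var_names=["upper_ell", "upper_eta", "lower_ell", "lower_eta", "sigma", "nu"]))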
Plotting
f_posts = az.extract(ppc.posterior_predictive, var_names=[f"f_preds_{i}" for i in range(n_grps)]).transpose("sample", ...)
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()
color_palette = ["Reds", "Greens", "Blues"]
for i in range(n_grps):
    plot_gp_dist(ax, f_posts[f"f_preds_{i}"], T, palette=color_palette[i], samples_alpha=0.5)
# axis labels and title
plt.xlabel("X")
plt.ylabel("True f(x)")
plt.title("Posterior distribution over $f(x)$ at the observed values")
plt.show()
fig, axes = plt.subplots(n_grps + 1, 1, figsize=(10, 4 * (n_grps + 1)), sharex=True)
color_palette = ["Reds", "Greens", "Blues"]
colors = ["red", "green", "blue"]
axes[0].plot(T, g, color="k", label="True Upper")
for i in range(n_grps):
    axes[0].plot(T, f_true[i], "-", color=colors[i], label=f"True Lower[{i}]")
    axes[0].plot(
        T, f_posts[f"f_preds_{i}"].mean(["sample"]).to_numpy(),
        "--", color=colors[i], alpha=0.5,
    )
    plot_gp_dist(axes[i + 1], f_posts[f"f_preds_{i}"], T, palette=color_palette[i])
    axes[i + 1].plot(T, f_true[i], color="dodgerblue", label="True Lower")
    axes[i + 1].plot(X_i[i], y_i[i], "ok", label="Samples")
    axes[i + 1].legend(loc="best")
axes[0].legend(loc="best")
plt.suptitle("Diagnostic Plot")
plt.show()