OK, this is not the best code I've written, but hopefully it gives you a starting point. It matches the first model in the example notebook, where the covariance function is shared by all the groups.
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
# Load the example dataset and treat the grouping column "fac" as categorical.
data = pd.read_csv(
    "https://raw.githubusercontent.com/bambinos/bambi/main/tests/data/gam_data.csv"
)
data["fac"] = pd.Categorical(data["fac"])
by_values = data["fac"].to_numpy()

# Center x2 within each level of "fac" by subtracting that level's mean.
# `observed=True` is spelled out so the result does not rely on the pandas
# default for categorical groupers.
x2_centered = (
    data["x2"] - data.groupby("fac", observed=True)["x2"].transform("mean")
).to_numpy()

by_levels = [1, 2, 3]
m_value = 12  # passed as `m` to every pm.gp.HSGP below
L_value = [0.78, 0.74, 0.73]  # one `L` per level; I took it from Bambi internals

# Coordinates are derived from the constants above so they cannot drift apart.
coords = {
    "gp_weights_dim": np.arange(m_value),
    "fac": [str(level) for level in by_levels],
    "__obs__": np.arange(len(data)),
}

# The per-level contributions are later concatenated level by level. The rank
# of each row under a stable sort by level maps that grouped order back to the
# original row order.
indexes_to_unsort = by_values.argsort(kind="mergesort").argsort(kind="mergesort")

# Accumulators filled while building the model (one entry per level).
phi_list, sqrt_psd_list = [], []
# Start from empty accumulators so that re-running this block (e.g. in a
# notebook) does not append to results left over from a previous run.
phi_list, sqrt_psd_list = [], []

with pm.Model(coords=coords) as model:
    hsgp_sigma = pm.Exponential("hsgp_sigma", lam=3)
    hsgp_ell = pm.Exponential("hsgp_ell", lam=3)
    # A single covariance function, shared by the HSGP of every level.
    cov_func = hsgp_sigma**2 * pm.gp.cov.ExpQuad(1, hsgp_ell)

    for i, level in enumerate(by_levels):
        # Named `level_gp` so it is not confused with the "hsgp" Deterministic
        # defined further down.
        level_gp = pm.gp.HSGP(
            m=[m_value],
            L=[L_value[i]],
            drop_first=False,
            cov_func=cov_func,
        )
        # Data has to be 2d
        level_x = x2_centered[by_values == level][:, None]
        phi, sqrt_psd = level_gp.prior_linearized(level_x)
        sqrt_psd_list.append(sqrt_psd)
        # The basis is evaluated to concrete values here.
        phi_list.append(phi.eval())

    # One column per level, matching dims ("gp_weights_dim", "fac").
    sqrt_psd = pt.stack(sqrt_psd_list, axis=1)
    coeffs_raw = pm.Normal("hsgp_weights_raw", dims=("gp_weights_dim", "fac"))
    coeffs = pm.Deterministic(
        "hsgp_weights",
        coeffs_raw * sqrt_psd,
        dims=("gp_weights_dim", "fac"),
    )

    # Each level's basis times that level's column of weights.
    contribution_list = [
        level_phi @ coeffs[:, i] for i, level_phi in enumerate(phi_list)
    ]

    # We need to unsort the contributions so they match the original data
    contribution = pt.concatenate(contribution_list)[indexes_to_unsort]
    hsgp = pm.Deterministic("hsgp", contribution, dims="__obs__")

    # Parameters for the likelihood
    mu = hsgp  # the GP contribution is used directly as the mean
    sigma = pm.HalfNormal("sigma", sigma=1)
    pm.Normal("y", mu=mu, sigma=sigma, dims="__obs__", observed=data["y"])
# Graph view of the model structure.
model.to_graphviz()

# Draw posterior samples with a fixed seed.
sampler_options = {"chains": 2, "target_accept": 0.95, "random_seed": 121195}
with model:
    idata = pm.sample(**sampler_options)

# Trace plots for the covariance hyperparameters and the observation noise.
# Assigning the result keeps it from being echoed, as the trailing `;` did.
traced_names = ["hsgp_ell", "hsgp_sigma", "sigma"]
trace_axes = az.plot_trace(
    idata,
    var_names=traced_names,
    backend_kwargs={"layout": "constrained"},
)
And those posteriors look like the ones you can find in the example notebook.