HSGP by levels of a categorical covariate

OK, this is not the best code I've written, but hopefully it gives you a starting point. It matches the first model in the example notebook, where all groups share the same covariance function.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

# Load the example data and treat "fac" as a categorical grouping variable.
data = pd.read_csv("https://raw.githubusercontent.com/bambinos/bambi/main/tests/data/gam_data.csv")
data["fac"] = pd.Categorical(data["fac"])

by_values = data["fac"].to_numpy()
# Center x2 within each level of "fac"; each level gets its own HSGP basis in the model.
# observed=True groups only by the levels actually present in the data.
x2_centered = (
    data["x2"] - data.groupby("fac", observed=True)["x2"].transform("mean")
).to_numpy()

# Levels of "fac", in the order their HSGPs are built and their contributions are
# concatenated inside the model.
by_levels = [1, 2, 3]

m_value = 12  # passed as `m` to each level's HSGP; also sizes the weights dimension
L_value = [0.78, 0.74, 0.73]  # one `L` per level, in by_levels order (taken from Bambi internals)
if len(L_value) != len(by_levels):
    raise ValueError("L_value must have one entry per level in by_levels")

coords = {
    "gp_weights_dim": np.arange(m_value),
    "fac": [str(level) for level in by_levels],
    "__obs__": np.arange(len(data)),
}

# The model concatenates the per-level contributions in by_levels order. Record the
# original row index of each position in that concatenation, then invert the
# permutation so the concatenated vector can be mapped back to the row order of `data`.
# Building it from by_levels (rather than from a value sort) keeps the two in sync
# even if by_levels is not in ascending order.
concat_order = np.concatenate([np.flatnonzero(by_values == level) for level in by_levels])
if len(concat_order) != len(by_values):
    raise ValueError("every value of data['fac'] must appear in by_levels")
indexes_to_unsort = np.argsort(concat_order, kind="mergesort")
phi_list, sqrt_psd_list = [], []

with pm.Model(coords=coords) as model:
    # Start from empty lists on every execution. Appending to lists left over from a
    # previous run (e.g. re-running this notebook cell) would stack stale bases and
    # break the shape of sqrt_psd against the ("gp_weights_dim", "fac") weights.
    phi_list, sqrt_psd_list = [], []

    # Covariance hyperparameters shared by all levels of "fac".
    hsgp_sigma = pm.Exponential("hsgp_sigma", lam=3)
    hsgp_ell = pm.Exponential("hsgp_ell", lam=3)

    cov_func = hsgp_sigma**2 * pm.gp.cov.ExpQuad(1, hsgp_ell)

    # One HSGP approximation per level: same covariance function, level-specific L.
    for level, level_L in zip(by_levels, L_value):
        level_gp = pm.gp.HSGP(
            m=[m_value],
            L=[level_L],
            drop_first=False,
            cov_func=cov_func,
        )

        # Data has to be 2d: (n_obs_in_level, 1).
        phi, sqrt_psd = level_gp.prior_linearized(x2_centered[by_values == level][:, None])
        sqrt_psd_list.append(sqrt_psd)
        # Evaluate the basis once to a fixed array; the covariance hyperparameters
        # enter the model through sqrt_psd, which stays symbolic.
        phi_list.append(phi.eval())

    # Shape (m_value, n_levels), matching the ("gp_weights_dim", "fac") dims below.
    sqrt_psd = pt.stack(sqrt_psd_list, axis=1)

    # Non-centered weights: standard-normal draws scaled by the per-level sqrt PSD.
    coeffs_raw = pm.Normal("hsgp_weights_raw", dims=("gp_weights_dim", "fac"))
    coeffs = pm.Deterministic(
        "hsgp_weights", coeffs_raw * sqrt_psd, dims=("gp_weights_dim", "fac")
    )

    # Per-level contributions are concatenated in by_levels order, then permuted
    # back so they line up with the rows of `data`.
    contribution = pt.concatenate(
        [level_phi @ coeffs[:, j] for j, level_phi in enumerate(phi_list)]
    )[indexes_to_unsort]

    hsgp = pm.Deterministic("hsgp", contribution, dims="__obs__")

    # Likelihood: in this model the mean is the HSGP term alone.
    mu = hsgp
    sigma = pm.HalfNormal("sigma", sigma=1)

    pm.Normal("y", mu=mu, sigma=sigma, dims="__obs__", observed=data["y"])


model.to_graphviz()

# Posterior sampling: two chains, target_accept=0.95, fixed seed for reproducibility.
sampler_settings = {"chains": 2, "target_accept": 0.95, "random_seed": 121195}
with model:
    idata = pm.sample(**sampler_settings)

# Trace plots for the shared covariance hyperparameters and the observation noise.
traced_names = ["hsgp_ell", "hsgp_sigma", "sigma"]
az.plot_trace(
    idata,
    var_names=traced_names,
    backend_kwargs={"layout": "constrained"},
);

And those posteriors look like the ones in the example.