HSGP by levels of a categorical covariate

I am attempting to fit a model with an HSGP contribution grouped by the levels of a categorical variable, with a different HSGP for each level. This is exactly what is done in this Bambi tutorial, but I need to do the same in PyMC v5. Can anyone please point me in the right direction?

Ideally, I would use the prior_linearized version of HSGP with the categorical covariate, so that I can later use set_data() to make predictions on out-of-sample data.
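For context on why prior_linearized suits set_data(): it returns two pieces, a basis matrix phi that depends only on the inputs and the boundary L, and sqrt_psd, the square root of the kernel's spectral density at the basis frequencies. The GP is then approximately phi @ (beta * sqrt_psd) with beta ~ Normal(0, 1). Only phi depends on X, so predicting at new points should just mean recomputing phi with the same (fixed) L and reusing the weights. The NumPy sketch below assumes the standard sine basis on [-L, L] and an ExpQuad kernel; the function names and hyperparameter values are hypothetical, and this is not PyMC's source code.

```python
import numpy as np

def hsgp_basis(x, m, L):
    """Laplacian eigenfunctions on [-L, L] (Dirichlet boundaries); returns (phi, omega)."""
    omega = np.pi * np.arange(1, m + 1) / (2 * L)          # square roots of the eigenvalues, (m,)
    phi = np.sqrt(1 / L) * np.sin(np.outer(x + L, omega))  # basis evaluated at x, (n, m)
    return phi, omega

def expquad_sqrt_psd(omega, sigma, ell):
    """Square root of the 1-D ExpQuad spectral density evaluated at omega."""
    psd = sigma**2 * np.sqrt(2 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)
    return np.sqrt(psd)

rng = np.random.default_rng(0)
m, L = 12, 0.75            # L = c * max|x| with c = 1.5 and x in [-0.5, 0.5]
sigma, ell = 1.0, 0.15     # made-up kernel hyperparameters

x_train = np.linspace(-0.5, 0.5, 50)
phi_train, omega = hsgp_basis(x_train, m, L)
sqrt_psd = expquad_sqrt_psd(omega, sigma, ell)

beta = rng.normal(size=m)                 # raw weights, beta ~ Normal(0, 1)
f_train = phi_train @ (beta * sqrt_psd)   # one prior draw of f at x_train

# Out of sample: new basis rows at the new inputs with the SAME L and the same weights
x_new = np.array([-0.25, 0.0, 0.4])
phi_new, _ = hsgp_basis(x_new, m, L)
f_new = phi_new @ (beta * sqrt_psd)

# Sanity check: phi diag(psd) phi^T approximates the exact ExpQuad covariance
K_approx = (phi_train * sqrt_psd**2) @ phi_train.T
K_exact = sigma**2 * np.exp(-0.5 * (x_train[:, None] - x_train[None, :]) ** 2 / ell**2)
print(np.max(np.abs(K_approx - K_exact)))  # small: boundary and truncation error only
```

If L were recomputed from the new inputs instead of held fixed, the basis would change and the fitted weights would no longer line up with it.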

The outputs should look like the ones below, with a separate HSGP for each level of the categorical variable:

Bambi has a very clean syntax to accomplish this using the by argument as such: hsgp(x2, by=fac, m=12, c=1.5). I’m looking to do the same thing, but in PyMC.


Check out this example. It’s pretty new but should answer your question.


I came up with a solution to implement HSGP by categorical covariates in PyMC, and the result (below) appears to match the Bambi tutorial:

My model looks like this:

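import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

# Prepare the data (assumed: the same gam_data.csv used later in this thread)
data = pd.read_csv("https://raw.githubusercontent.com/bambinos/bambi/main/tests/data/gam_data.csv")
fac = sorted(data['fac'].unique())

# Create coordinate reference
coords = {
    'fac': fac,
    'obs_id': np.arange(len(data))
}

# Map each level of fac to an integer index
fac_lookup = {level: i for i, level in enumerate(fac)}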

with pm.Model(coords=coords) as model:
    
    # Data containers
    fac_idx = pm.Data("fac_idx", [fac_lookup[i] for i in data['fac']])
    x2 = pm.Data("x2", data['x2'].values)
    y = pm.Data("y", data['y'].values)
    
    # HSGP priors
    length_scale = pm.Exponential("ell", lam=3, dims='fac')
    amplitude = pm.HalfCauchy('amplitude', 10, dims='fac')
    
    # Instantiate category specific HSGPs
    f_s = []
    for i in range(len(fac)):
        cov = (amplitude[i]**2) * pm.gp.cov.ExpQuad(1, ls=length_scale[i])
        gp = pm.gp.HSGP(m=[12], c=1.5, cov_func=cov)
        f = gp.prior(f"HSGP_{fac[i]}", X=x2[:, None])
        f_s.append(f)
        
    # Combine category-specific effects using stack and advanced indexing:
    # element [fac_idx[k], k] is observation k's value under its own level's GP
    stacked_effects = pt.stack(f_s)  # Shape: (n_levels, n_observations)
    alpha_x2 = pm.Deterministic("alpha_x2", 
                                 stacked_effects[fac_idx, pt.arange(stacked_effects.shape[1])])
    
    sigma = pm.HalfNormal("sigma", 1)
    
    # Likelihood
    y_hat = pm.Normal("y_hat", mu=alpha_x2, sigma=sigma, observed=y, dims='obs_id')
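The stack-plus-advanced-indexing step picks, for each observation, the value of the curve belonging to that observation's level. NumPy indexing follows the same rules, which makes the trick easy to check on a toy example (the numbers below are made up, not taken from the model):

```python
import numpy as np

# Toy stand-ins: 3 levels, 5 observations
f_s = [np.array([0.0, 0.1, 0.2, 0.3, 0.4]),   # curve for level 0 at every observation
       np.array([1.0, 1.1, 1.2, 1.3, 1.4]),   # curve for level 1
       np.array([2.0, 2.1, 2.2, 2.3, 2.4])]   # curve for level 2
fac_idx = np.array([2, 0, 1, 1, 0])           # level of each observation

stacked = np.stack(f_s)                       # shape (n_levels, n_obs)
# Entry k comes from row fac_idx[k], column k
picked = stacked[fac_idx, np.arange(stacked.shape[1])]
print(picked)
```

Note the cost of this design: every level's GP is evaluated at all observations, and only one value per column is kept.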

The model architecture is pretty ugly:

Compared to the model architecture in the Bambi tutorial, which is very clean:

But the results seem to match.

I would be very appreciative if you have any feedback on my approach above @bwengals @tcapretto. Thank you.


That’s right, it’s OK to use for-loops to build each GP one by one, but that’s not what Bambi does under the hood.

You can simplify this example so it is not hierarchical, which gives a model similar to the one Bambi builds.
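One loop-free way to get a separate curve per level, sketched in NumPy (this is not Bambi's actual code): build each observation's basis row using its own level's L, keep one column of weights per level, and take a row-wise dot product with the column selected by fac_idx. The helper name, data, and weights are hypothetical; the L values are the ones quoted later in this thread.

```python
import numpy as np

def basis_per_obs(x, m, L):
    """Hypothetical helper: HSGP sine basis where each observation may have its own L."""
    x = np.asarray(x, dtype=float)[:, None]                                 # (n, 1)
    L = np.broadcast_to(np.asarray(L, dtype=float), x.shape[:1])[:, None]  # (n, 1)
    omega = np.pi * np.arange(1, m + 1) / (2 * L)                           # (n, m)
    return np.sqrt(1 / L) * np.sin(omega * (x + L))                         # (n, m)

rng = np.random.default_rng(1)
m, n_levels, n_obs = 12, 3, 30
L_levels = np.array([0.78, 0.74, 0.73])            # one boundary per level
fac_idx = rng.integers(0, n_levels, size=n_obs)    # made-up level of each observation
x = rng.uniform(-0.5, 0.5, size=n_obs)             # made-up centered inputs
weights = rng.normal(size=(m, n_levels))           # stand-in for coeffs_raw * sqrt_psd

# Vectorized: every row uses its own level's L and its own level's weight column
phi = basis_per_obs(x, m, L_levels[fac_idx])            # (n_obs, m)
f_vec = np.sum(phi * weights[:, fac_idx].T, axis=1)     # (n_obs,)

# Reference: the loop-over-levels version
f_loop = np.empty(n_obs)
for level in range(n_levels):
    mask = fac_idx == level
    f_loop[mask] = basis_per_obs(x[mask], m, L_levels[level]) @ weights[:, level]

print(np.allclose(f_vec, f_loop))  # True
```

Assuming the same expressions are written with pytensor.tensor operations on pm.Data inputs for x and fac_idx, set_data() should then be able to swap in new observations without rebuilding the model.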

Hi @wgeary!

I’m happy the example in Bambi is helpful. Let me give you some pointers to where you can read about our implementation. Before that, note that the implementation is quite general because it has to work with many different configurations, so your own code will likely look different.

First, when you do hsgp(x2, by=fac, m=12, c=1.5) in a Bambi model formula, the information is passed to the transformation implemented here

Think of it as a pre-processing step that makes sure the values passed are the ones we expect.

Then, Bambi creates internal Term objects. There is a specific type for HSGP terms (because they carry specific metadata). You can find it here

This part should not be very relevant, as it mostly handles metadata. However, I just saw that it also computes L (the boundary of the approximation domain), so you may find it relevant.
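A common rule for the boundary is L = c * max|x - mean(x)|, computed on the centered inputs; applied per level, it yields one L per group. The sketch below assumes that rule (whether Bambi uses exactly this formula is an assumption, not confirmed here), uses a hypothetical helper name, and runs on made-up data spread over [0, 1]:

```python
import numpy as np

def boundary_per_level(x, levels, c):
    """Hypothetical helper: L = c * max|x - mean(x)| within each level."""
    out = {}
    for level in np.unique(levels):
        x_level = x[levels == level]
        out[level] = c * np.max(np.abs(x_level - x_level.mean()))
    return out

rng = np.random.default_rng(2)
x2 = rng.uniform(0, 1, size=300)     # made-up inputs on [0, 1]
fac = rng.integers(1, 4, size=300)   # made-up levels 1, 2, 3
print(boundary_per_level(x2, fac, c=1.5))  # roughly 1.5 * 0.5 = 0.75 per level
```

With inputs spanning roughly [0, 1] and c = 1.5, this rule gives values in the same range as the L_value list used in the code below.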

Finally, this is where Bambi translates the Bambi HSGP Term into PyMC.

The most relevant part is in the .build() method.


OK, this is not the best code I have written, but hopefully it gives you a starting point. It matches the first model in the example notebook, where the covariance function is shared by all the groups.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

data = pd.read_csv("https://raw.githubusercontent.com/bambinos/bambi/main/tests/data/gam_data.csv")
data["fac"] = pd.Categorical(data["fac"])

by_values = data["fac"].to_numpy()
x2_centered = (
    data["x2"] - data[["x2", "fac"]].groupby("fac").transform("mean")["x2"]
).to_numpy()

by_levels = [1, 2, 3]

m_value = 12
L_value = [0.78, 0.74, 0.73]  # I took it from Bambi internals

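coords = {
    "gp_weights_dim": np.arange(m_value),
    "fac": ["1", "2", "3"],
    "__obs__": np.arange(len(data))
}


# Position of each row after a stable sort by level; used below to put the
# concatenated per-level contributions back in the original row order
indexes_to_unsort = by_values.argsort(kind="mergesort").argsort(kind="mergesort")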
phi_list, sqrt_psd_list = [], []

with pm.Model(coords=coords) as model:
    hsgp_sigma = pm.Exponential("hsgp_sigma", lam=3)
    hsgp_ell = pm.Exponential("hsgp_ell", lam=3)
    
    cov_func = hsgp_sigma**2 * pm.gp.cov.ExpQuad(1, hsgp_ell)

    for i, level in enumerate(by_levels):
        hsgp = pm.gp.HSGP(
            m=[m_value],
            L=[L_value[i]],
            drop_first=False,
            cov_func=cov_func,
        )

        # Data has to be 2d
        phi, sqrt_psd = hsgp.prior_linearized(x2_centered[by_values == level][:, None])
        sqrt_psd_list.append(sqrt_psd)
        phi_list.append(phi.eval())

    sqrt_psd = pt.stack(sqrt_psd_list, axis=1)

    coeffs_raw = pm.Normal("hsgp_weights_raw", dims=("gp_weights_dim", "fac"))
    coeffs = pm.Deterministic("hsgp_weights", coeffs_raw * sqrt_psd, dims=("gp_weights_dim", "fac"))

    contribution_list = []
    for i in range(len(by_levels)):
        contribution_list.append(phi_list[i] @ coeffs[:, i])
    # We need to unsort the contributions so they match the original data
    contribution = pt.concatenate(contribution_list)[indexes_to_unsort]

    hsgp = pm.Deterministic("hsgp", contribution, dims="__obs__")


    # Parameters for the likelihood
    mu = hsgp  # in this model the mean is just the HSGP contribution
    sigma = pm.HalfNormal("sigma", sigma=1)

    pm.Normal("y", mu=mu, sigma=sigma, dims="__obs__", observed=data["y"])


model.to_graphviz()

with model:
    idata = pm.sample(chains=2, target_accept=0.95, random_seed=121195)
az.plot_trace(
    idata, 
    var_names=["hsgp_ell", "hsgp_sigma", "sigma"], 
    backend_kwargs={"layout": "constrained"}
);

Those posteriors look like the ones you can find in the example.
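The indexes_to_unsort trick relies on a stable sort: concatenating the per-level pieces in ascending level order lists the observations sorted by level, and applying argsort twice gives each original row's position in that sorted order. A small NumPy check with made-up values:

```python
import numpy as np

by_values = np.array([2, 1, 3, 1, 2, 3, 1])
values = np.arange(len(by_values)) * 10.0   # stand-in for per-observation contributions

# Concatenate per-level pieces in level order, as the model does
pieces = [values[by_values == level] for level in [1, 2, 3]]
concatenated = np.concatenate(pieces)

# Rank of each row under a stable sort by level
indexes_to_unsort = by_values.argsort(kind="mergesort").argsort(kind="mergesort")
restored = concatenated[indexes_to_unsort]
print(np.array_equal(restored, values))  # True
```

This only works if by_levels is listed in the same ascending order the sort uses; a non-stable sort could also shuffle rows within a level, which is why kind="mergesort" is passed.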


A big belated thank you @tcapretto for everything that you’ve provided above. You’ve given me plenty to chew on. Extremely helpful - thanks again!
