Hierarchical GPs with varying lengthscales

Hi folks,

I’m actually testing a new implementation of hierarchical GPs that I hope will be faster than how they’re typically implemented, but I need a reasonable typical implementation to compare against. I came up with this:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

n_x = 10 # number of evaluation points for the GP
n_id = 20 # number of independent GPs
x = np.linspace(0, 1, n_x) # GP grid

with pm.Model() as model:
	x_Data = pm.Data("x_Data", x[:,None])
	y_Data = pm.Data("y_Data", np.random.randn(n_id,n_x))
	lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
	f_list = []
	cov_list = []
	gp_list = []
	for i_id in range(n_id):
		cov_list.append(pm.gp.cov.ExpQuad(1, ls=lengthscale[i_id]))
		gp_list.append(pm.gp.Latent(cov_func=cov_list[i_id]))
		f_list.append(gp_list[i_id].prior(f"f_{i_id}", X=x_Data))
	f = pt.stack(f_list, axis=0)  # [n_id, n_x]
	y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)

But this seems to take the compiler a while to sort out before even sampling. I’m a bit of a pymc newb still, so I figured I should check that I’m not missing a more efficient “typical” approach.

Thoughts?

Oh, after some playing, this seems to avoid the long compile times at least:


from pytensor.scan import scan

def get_f(i_id,lengthscale,z,x_Data,jitter):
	"""
	Compute f = L @ z for a single i_id, using that id's lengthscale.
	"""
	cov_i = pm.gp.cov.ExpQuad(1, ls=lengthscale[i_id]).full(x_Data)
	L_i = pt.linalg.cholesky(cov_i + jitter)
	f_i = L_i @ z[:,i_id]  # [n_x,n_x] @ [n_x] = [n_x]
	return f_i


jitter = 1e-6 * pt.eye(n_x)
with pm.Model() as model:
	x_Data = pm.Data("x_Data", x[:,None])
	y_Data = pm.Data("y_Data", np.random.randn(n_x,n_id))
	lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
	z = pm.Normal("z", mu=0, sigma=1, shape=(n_x,n_id))
	f, _ = scan(
		fn = get_f
		, sequences = [pt.arange(n_id)]
		, non_sequences = [lengthscale,z,x_Data,jitter]
	)
	y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)
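The scan version builds f explicitly through the non-centered parameterization: if z ~ N(0, I) and L is the lower Cholesky factor of K, then f = L z has covariance L Lᵀ = K. A minimal NumPy check of that identity, assuming a unit-amplitude ExpQuad kernel; `expquad_cov` is a hypothetical helper standing in for `pm.gp.cov.ExpQuad(1, ls=...).full(...)`, not a PyMC function:

```python
import numpy as np


def expquad_cov(x, ls):
    """Hypothetical helper: unit-amplitude ExpQuad covariance on a 1-D grid."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)


rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
K = expquad_cov(x, ls=0.3) + 1e-6 * np.eye(x.size)  # same 1e-6 jitter as above
L = np.linalg.cholesky(K)                           # lower-triangular, K == L @ L.T

# Non-centered draws: columns of z are independent N(0, I) vectors,
# so each column of f = L @ z is a draw from N(0, K).
z = rng.standard_normal((x.size, 200_000))
f = L @ z
K_hat = np.cov(f)  # empirical covariance across the 200k draws
```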

If anyone has any further suggestions to make this a reasonable baseline against which I’ll be comparing my alternative, let me know!

What’s the current “baseline” way to do it that you’re looking to improve?

Edit: I couldn’t run your scan method as posted, it gave a shape error. The following worked:

import pytensor
from pymc.gp.util import stabilize

def get_f(lengthscale, z, data):
    cov_i = pm.gp.cov.ExpQuad(1, ls=lengthscale).full(data)
    L_i = pt.linalg.cholesky(stabilize(cov_i))
    f_i = L_i @ z
    
    return f_i

with pm.Model() as model:
    x_Data = pm.Data("x_Data", x[:,None])
    y_Data = pm.Data("y_Data", np.random.randn(n_x,n_id))
    
    lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
    zs = pm.Normal("z", mu=0, sigma=1, shape=(n_id, n_x))
    
    f, _ = pytensor.scan(
        fn = get_f,
        sequences = [lengthscale, zs],
        non_sequences=[x_Data]
    )
    
    y = pm.Normal("y", mu=f.T, sigma=1, observed=y_Data)
    idata = pm.sample(compile_kwargs={'mode':'NUMBA'})

You could also try vectorizing the creation of f instead of scanning. This amounts to a compute/memory tradeoff: vectorizing computes all n_id covariance matrices at once rather than one at a time, so if you have enough memory it will usually be faster than scanning, since the computation isn’t done sequentially.

with pm.Model() as model:
    x_Data = pm.Data("x_Data", x[:,None])
    y_Data = pm.Data("y_Data", np.random.randn(n_x,n_id))
    
    lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
    zs = pm.Normal("z", mu=0, sigma=1, shape=(n_id, n_x))
    
    f = pt.vectorize(get_f, signature=('(),(x),(x,o)->(x)'))(lengthscale, zs, x_Data)
    
    y = pm.Normal("y", mu=f.T, sigma=1, observed=y_Data)
    idata = pm.sample(compile_kwargs={'mode':'FAST_RUN'})

In this case I got better sampling speed with scan on the numba backend than with vectorize on either the C or the numba backend. On the C backend, however, vectorize was faster than scan.
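A rough NumPy analogue of the scan-versus-vectorize tradeoff, outside PyTensor (so it says nothing about the backend timings above): the loop handles one id at a time with only n_x × n_x working memory, while the batched version materializes all n_id covariance matrices as a single (n_id, n_x, n_x) array and hands the whole stack to `np.linalg.cholesky`, which accepts leading batch dimensions. The helper `expquad_cov` and the sizes are assumptions for illustration.

```python
import numpy as np


def expquad_cov(x, ls):
    """Hypothetical helper: unit-amplitude ExpQuad covariance on a 1-D grid."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)


rng = np.random.default_rng(1)
n_x, n_id = 10, 20
x = np.linspace(0, 1, n_x)
ls = rng.weibull(2.0, size=n_id)  # Weibull(alpha=2, beta=1) lengthscales, as in the posts
z = rng.standard_normal((n_id, n_x))
jitter = 1e-6 * np.eye(n_x)

# "scan"-style: one small problem at a time, sequentially
f_loop = np.stack([np.linalg.cholesky(expquad_cov(x, l) + jitter) @ zi
                   for l, zi in zip(ls, z)])

# "vectorize"-style: build every covariance matrix at once, then factor the stack
d = x[:, None] - x[None, :]                                      # (n_x, n_x)
K = np.exp(-0.5 * (d[None] / ls[:, None, None]) ** 2) + jitter   # (n_id, n_x, n_x)
L = np.linalg.cholesky(K)                                        # batched over axis 0
f_vec = np.einsum("ijk,ik->ij", L, z)                            # (n_id, n_x)
```

The batched version trades memory (n_id · n_x² floats held at once) for fewer sequential steps, which is the same tradeoff `pt.vectorize` makes relative to `scan`.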


That was my question :slight_smile:

I’ll post my faster method when I have it more rigorously quantified, but it involves computing the Cholesky decomposition more efficiently by exploiting the symmetric-Toeplitz structure of the correlation matrix, then computing the matrix-vector product more efficiently by exploiting the fact that L is lower-triangular, using a forward-substitution-style algorithm.
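Both structural facts can be checked in a few lines of NumPy. On an evenly spaced grid, the ExpQuad correlation between points i and j depends only on |i - j|, so the whole matrix can be rebuilt from its first row using the same `toeplitz_indices` construction as the script below; and in f = L z with lower-triangular L, row i only needs z[:i+1]. The grid size and lengthscale here are arbitrary assumptions.

```python
import numpy as np

n_x, ls = 6, 0.4
x = np.linspace(0, 1, n_x)  # evenly spaced grid (required for Toeplitz structure)

# Dense ExpQuad correlation matrix
K_full = np.exp(-0.5 * ((x[:, None] - x[None, :]) / ls) ** 2)

# Toeplitz reconstruction from the first row r, via |i - j| index differences
r = np.exp(-0.5 * ((x - x[0]) / ls) ** 2)
toeplitz_indices = np.abs(np.arange(n_x)[None, :] - np.arange(n_x)[:, None])
K_toep = r[toeplitz_indices]

# Lower-triangular product: row i of L is zero beyond column i,
# so f[i] only needs the first i + 1 entries of z.
L = np.linalg.cholesky(K_toep + 1e-6 * np.eye(n_x))
z = np.random.default_rng(2).standard_normal(n_x)
f = np.array([L[i, : i + 1] @ z[: i + 1] for i in range(n_x)])
```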

Thanks for the heads-up on vectorize(); I hadn’t come across it yet and should clearly add it to my toolbox.

Help wanted on this issue :slight_smile:

Ah, I should have anticipated someone would be on this already! In my testing so far, I’m not seeing any speedup when decomposing and multiplying a single matrix (indeed it’s much slower), but it wouldn’t surprise me if that’s merely because cholesky() has its gradient hand-coded. I am seeing solid speedups for the hierarchical case, scaling with n_id as expected, but the speedup diminishes and then reverses as n_x increases, leaving a fairly narrow window of useful n_x values (roughly 2-8). Still, the original problem that inspired this had n_x=5 and n_id=1e6, where this should save days of compute, so it may still be worth working out the full integration in PyTensor.

My implementation is below, along with a comparison against the usual approach; what would you suggest as the first steps toward integrating this?

# exec(open("gp_as_sem_test.py").read())
use_numba = False  # Set to True if you want to use Numba for compilation
n_x = 2
n_id = 100


import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor.scan import scan
import os
import time
import humanize

np.set_printoptions(linewidth=np.inf)

class time_tracker():
	def __init__(self, name=None):
		self.name = name
		self.start_time = time.time()
		print(f"{self.name}...")

	def elapsed(self):
		self.duration = time.time() - self.start_time
		print(f"{self.name} duration: {humanize.precisedelta(self.duration, minimum_unit='seconds')}")


def sample(model):
	n_chains = int(os.cpu_count()/2)
	sampling_tracker = time_tracker("Sampling")
	with model:
		posterior_trace = pm.sample(
			cores = n_chains
			, chains = n_chains
			, nuts_sampler = 'pymc'
			, seed = 112358
			, compile_kwargs = {'mode':'NUMBA'} if use_numba else {'mode':'FAST_RUN'}	
		)
		sampling_tracker.elapsed()
	return posterior_trace


def bareiss_cholesky_toeplitz_batch_pt(r_batch,inds,eps=1e-12):
	"""
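	Compute Cholesky factors for many symmetric Toeplitz matrices
	using fraction‐free (Bareiss) elimination.

	Note: each step updates the full trailing Schur complement, so the cost
	is O(n^3) per matrix; the Toeplitz structure is only used to build M0.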

	Parameters
	----------
	r_batch : TensorVariable, shape (q, n)
	    Each row i is the first column (and row) of the i-th Toeplitz matrix.
	    Must satisfy r_batch[:, 0] > 0.
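	inds : ndarray of int, shape (n, n)
	    Index-difference array with inds[i, j] == abs(i - j), used to expand
	    each row of r_batch into its full Toeplitz matrix.
	eps : float, optional
	    Tiny constant to guard against division by zero or negative sqrt
	    due to numerical drift.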

	Returns
	-------
	L_final : TensorVariable, shape (q, n, n)
	    Lower‐triangular factors so that for each i:
	        L_final[i] @ L_final[i].T == toeplitz_matrix_from(r_batch[i]).
	"""
	# ------------------------------------------------------------------------
	# 1) Extract the batch dimensions (symbolic scalars from r_batch.shape)
	# ------------------------------------------------------------------------
	q, n = r_batch.shape  # q = number of matrices, n = matrix size

	# ------------------------------------------------------------------------
	# 2) Build the batched Toeplitz tensor M of shape (q, n, n)
	#    via advanced indexing with a constant index‐difference array.
	# ------------------------------------------------------------------------
	# Broadcast r_batch over a pre-computed index‐difference array: M[i] = toeplitz(r_batch[i])
	M0 = r_batch[:, inds]  # shape (q, n, n)

	# ------------------------------------------------------------------------
	# 3) Initialize the output L and the “previous pivot” vector
	# ------------------------------------------------------------------------
	# L will be filled column by column; start with zeros
	L0 = pt.zeros((q, n, n), dtype=r_batch.dtype)
	# pivot_prev starts as all ones (since the Bareiss algorithm uses 1 as the
	# previous pivot before the first step)
	pivot_prev0 = pt.ones((q,), dtype=r_batch.dtype)

	# ------------------------------------------------------------------------
	# 4) Define the scan step: one iteration of k=0..n-1
	# ------------------------------------------------------------------------
	def step(k, M_prev, pivot_prev, L_prev):
		"""
		Single Bareiss pass at column k for all matrices in the batch.
		Inputs:
		  k          : scalar index (0 <= k < n)
		  M_prev     : tensor (q, n, n), current Schur‐complement blocks
		  pivot_prev : vector (q,), previous pivot values
		  L_prev     : tensor (q, n, n), accumulated L up to column k-1
		Outputs:
		  M_next     : updated Schur‐complement for next iteration
		  pivot      : current pivot (q,)
		  L_next     : L_prev with column k filled in
		"""
		# 4a) Extract the current pivot entries for each matrix
		pivot = M_prev[:, k, k]  # shape (q,)

		# 4b) Compute the denominator = sqrt(pivot * pivot_prev), clipped below at eps
		denom = pt.sqrt(pt.clip(pivot * pivot_prev, eps, pt.inf))  # shape (q,)

		# 4c) Fill in L's diagonal entry L[:,k,k] = pivot/denom
		L1 = pt.set_subtensor(L_prev[:, k, k], pivot / denom)

		# 4d) Fill in the below‐diagonal entries L[:,k+1:,k] = M[:,k+1:,k]/denom
		# Note: if k == n-1, the slice is empty, so this is a no-op
		L1 = pt.set_subtensor(
		    L1[:, k+1:, k],
		    M_prev[:, k+1:, k] / denom[:, None]
		)

		# 4e) Compute the next Schur‐complement block via the Bareiss update:
		#      M_ij ← (M_ij * pivot - M_i,k * M_k,j) / pivot_prev
		num = (
		    M_prev[:, k+1:, k+1:] * pivot[:, None, None]
		    - M_prev[:, k+1:, k][:, :, None] * M_prev[:, k, k+1:][:, None, :]
		)
		M1 = pt.set_subtensor(
		    M_prev[:, k+1:, k+1:],
		    num / pivot_prev[:, None, None]
		)

		# 4f) Zero out the eliminated column below the pivot (optional)
		# M1 = pt.set_subtensor(M1[:, k+1:, k], 0.0)

		# Return new states for the next iteration
		return M1, pivot, L1

	# ------------------------------------------------------------------------
	# 5) Execute scan over k = 0,1,...,n-1
	# ------------------------------------------------------------------------
	# sequences               = the k‐indices
	# outputs_info            = initial values for [M, pivot_prev, L]
	result,updates = scan(
	    fn=step
	    , sequences=pt.arange(n)
	    , outputs_info=[M0, pivot_prev0, L0]
		, return_list=True
	)
	M_seq, piv_seq, L_seq = result

	# ------------------------------------------------------------------------
	# 6) Extract the final L from the last scan output
	#    L_seq has shape (n, q, n, n); we want the last time step
	# ------------------------------------------------------------------------
	L_final = L_seq[-1]  # shape (q, n, n)

	return L_final


def get_f_multi_pt(L, z):
	"""
	Compute the batched lower‐triangular matrix-vector product f = L⋅z
	for many matrices L, using PyTensor’s scan(). Row i only uses
	L[:, i, :i+1], as in forward substitution.

	Parameters
	----------
	L : TensorVariable, shape (n_id, n_x, n_x)
	    Batch of lower‐triangular matrices.
	z : TensorVariable, shape (n_id, n_x)
	    Batch of vectors to multiply.

	Returns
	-------
	f_final : TensorVariable, shape (n_id, n_x)
	    Products f such that for each i:
	        f[i] = L[i] @ z[i]
	"""
	# Extract dimensions (symbolic)
	n_id, n_x = z.shape

	# 1) Compute f[:,0] directly: only L[:,0,0] contributes to the first component
	f0 = L[:, 0, 0] * z[:, 0]  # shape (n_id,)

	# 2) Define the scan step for i=1..n_x-1
	def step(i, L, z):
		"""
		Compute f[:,i] = sum_j=0..i L[:,i,j] * z[:,j] for each batch.
		Inputs:
		  i : scalar index into the second dimension
		  L : (n_id, n_x, n_x) batch of matrices
		  z : (n_id, n_x) batch of vectors
		Returns:
		  f_i : (n_id,) the i-th component of f for each batch
		"""
		# Slice L[:,i,:i+1]  → shape (n_id, i+1)
		# Slice z[:,:i+1]    → shape (n_id, i+1)
		# Elementwise multiply and sum across the last axis
		f_i = pt.sum(L[:, i, :i+1] * z[:, :i+1], axis=1)
		return f_i  # shape (n_id,)

	# 3) Run scan over indices 1..n_x-1
	#    sequences       = the time‐steps i
	#    non_sequences   = L and z (they remain constant)
	#    outputs_info    = None (we only collect the returned f_i’s)
	f_seq, _ = scan(
		fn=step,
		sequences=pt.arange(1, n_x),
		non_sequences=[L, z]
	)

	# f_seq has shape (n_x-1, n_id); transpose to (n_id, n_x-1)
	f_rest = f_seq.T

	# 4) Concatenate f0 and the rest to form full f of shape (n_id, n_x)
	f_final = pt.concatenate([f0[:, None], f_rest], axis=1)

	return f_final


x = np.linspace(0, 1, n_x)
dists = (x[1:]-x[0])**2
neg_half_dists = np.repeat(-0.5 * dists[None,:], axis=0, repeats=n_id)

jitter = 1e-6
one_plus_jitter = 1.0 + jitter
one_plus_jitter_rep_n_id_2D = np.array([one_plus_jitter]*n_id).reshape((n_id,1))
toeplitz_indices = np.abs(np.arange(n_x)[None, :] - np.arange(n_x)[:, None])


with pm.Model() as new_model:
	y_Data = pm.Data("y_Data", np.random.randn(n_id, n_x))  # Simulated observed data
	lengthscale_squared = pm.Weibull("lengthscale_squared", alpha=1.0, beta=1.0, shape=n_id)
	# r: implied correlation given lengthscale and distance grid
	r = pt.concatenate(
		[
			one_plus_jitter_rep_n_id_2D
			, pm.math.exp(
				neg_half_dists
				/ pt.repeat(lengthscale_squared[:,None], axis=1, repeats=n_x-1)  # [n_id, n_x-1]
			)
		]
		, axis = 1
	)
	L = bareiss_cholesky_toeplitz_batch_pt(r,toeplitz_indices)  # [n_id, n_x, n_x]
	z = pm.Normal("z", mu=0, sigma=1, shape=(n_id,n_x))  # [n_id, n_x]
	f = get_f_multi_pt(L, z)  # [n_id, n_x]
	y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)

# new_model.debug(verbose=True)

# sample the prior
with new_model:
	prior_trace = pm.sample_prior_predictive(
		draws = 1
		, var_names = ["y"]
		, random_seed = 112358
	)

# Set the data for the model
with new_model:
	pm.set_data({"y_Data": prior_trace.prior_predictive.y.to_numpy().squeeze()})

# sample
new_posterior_trace = sample(new_model)



# For comparison, below is the same model but using cholesky() and @
from pymc.gp.util import stabilize

def get_f_i(i_id,lengthscale,z,x):
	"""
	Compute f = L @ z for a single individual, using the given lengthscale.
	"""
	cov = pm.gp.cov.ExpQuad(1, ls=lengthscale[i_id]).full(x)
	L = pt.linalg.cholesky(stabilize(cov))
	f = L @ z[i_id]  # [n_x,n_x] @ [n_x] = [n_x]
	return f


def get_f_vec(lengthscale,z,x):
	"""
	Compute f = L @ z for a single individual, using the given lengthscale.
	"""
	cov = pm.gp.cov.ExpQuad(1, ls=lengthscale).full(x)
	L = pt.linalg.cholesky(stabilize(cov))
	f = L @ z  # [n_x,n_x] @ [n_x] = [n_x]
	return f


x_2D = x[:,None]  # [n_x,1]
id_sequence = np.arange(n_id)
with pm.Model() as usual_model:
	y_Data = pm.Data("y_Data", np.random.randn(n_id,n_x))
	x_2D_Data = pm.Data("x_2D_Data", x_2D)  # [n_x,1]
	lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
	z = pm.Normal("z", mu=0, sigma=1, shape=(n_id,n_x))
	f = pt.vectorize(get_f_vec, signature=(f'(),({n_x}),({n_x},1)->({n_x})'))(lengthscale, z, x_2D_Data)
	# f, _ = scan(
	# 	fn = get_f_i
	# 	, sequences = [id_sequence]
	# 	, non_sequences = [lengthscale,z,x_2D]
	# )
	y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)

# usual_model.debug()

# sample the prior
with usual_model:
	prior_trace = pm.sample_prior_predictive(
		draws = 1
		, var_names = ["y"]
		, random_seed = 112358
	)

# Set the data for the model
with usual_model:
	pm.set_data({"y_Data": prior_trace.prior_predictive.y.to_numpy().squeeze()})

# sample
usual_posterior_trace = sample(usual_model)

My gut reaction is that it typically shouldn’t be possible to beat the numerical implementations in BLAS/LAPACK. I benchmarked your scan-based dot against a normal vectorized dot and didn’t find any speedup from using the scan:

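import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy import linalg
from pytensor.graph.replace import vectorize_graph
from pytensor.compile.mode import get_default_mode

rng = np.random.default_rng()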

L = pt.dmatrix('L')
x = pt.dvector('x')

L_batched = pt.tensor('L_batched', shape=(None, None, None))
x_batched = pt.dmatrix('x_batched')

normal_dot = L @ x

fn_0 = pytensor.function([L_batched, x_batched], 
                         vectorize_graph(normal_dot, {L:L_batched, x:x_batched}))
fn_1 = pytensor.function([L_batched, x_batched], 
                         get_f_multi_pt(L_batched, x_batched))

n_id, n_x = int(1e6), 5
A_val = rng.normal(size=(n_x, n_x))
A_val = A_val @ A_val.T
L_val = linalg.cholesky(A_val, lower=True)

x_val = rng.normal(size=(n_x,))

L_batched_val = np.stack([L_val for _ in range(n_id)], axis=0)
x_batched_val = np.stack([x_val for _ in range(n_id)], axis=0)

%timeit fn_0(L_batched_val, x_batched_val)
# 19.9 ms ± 366 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit fn_1(L_batched_val, x_batched_val)
# 27.9 ms ± 589 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

My guess is that given your setup with a ton of really small problems (n_id >> n_x), you are either facing a memory bottleneck in the “normal” dot case, or you are paying overhead costs over and over going back and forth between python and compiled subroutines.

I also tested using trmv, but I think this is always going to be slower in your case, since you’re dealing with lots of small matrices. If I were you, the next thing I’d try is to think about forming a single huge matrix A = block_diag(Ls). That would be too huge to actually create, but if you think about it, the resulting matrix will be banded lower-triangular, with n_x - 1 subdiagonals (one fewer than the size of each L, assuming L is lower). So you could write it in banded form and try using dtbmv, something like dtbmv(A, z_batched.ravel()).reshape(z_batched.shape). That might let you avoid all the back-and-forth that comes from your problem setup, while still exploiting the structure of L.
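The banded layout can be sketched in NumPy to check the indexing. This assumes LAPACK-style lower band storage, where entry A[i, j] of the block-diagonal matrix is kept at ab[i - j, j]; `banded_lower_matvec` is a hypothetical pure-NumPy stand-in for what a BLAS `dtbmv` call would do in compiled code, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(3)
n_id, n_x = 4, 3
# Stack of lower-triangular blocks (stand-ins for the per-id Cholesky factors)
Ls = np.tril(rng.standard_normal((n_id, n_x, n_x)))
z = rng.standard_normal((n_id, n_x))

N, k = n_id * n_x, n_x - 1  # size of block_diag(Ls) and its number of subdiagonals

# Lower band storage: ab[i - j, j] = A[i, j] for j <= i <= j + k
ab = np.zeros((k + 1, N))
for b in range(n_id):
    for i in range(n_x):
        for j in range(i + 1):
            row, col = b * n_x + i, b * n_x + j
            ab[row - col, col] = Ls[b, i, j]


def banded_lower_matvec(ab, x):
    """Hypothetical helper: y = A @ x for lower-triangular banded A stored as ab."""
    k = ab.shape[0] - 1
    n = x.size
    y = np.zeros_like(x)
    for d in range(k + 1):  # d-th subdiagonal: y[i] += A[i, i - d] * x[i - d]
        y[d:] += ab[d, : n - d] * x[: n - d]
    return y


f_band = banded_lower_matvec(ab, z.ravel()).reshape(z.shape)
f_ref = np.einsum("bij,bj->bi", Ls, z)  # per-block products, for comparison
```

The band array holds n_x · (n_id · n_x) numbers instead of (n_id · n_x)² for the dense block-diagonal matrix; entries that would cross block boundaries simply stay zero.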

I’m not familiar with the Bareiss algorithm, but my quick search found that it was related to numerical precision, not necessarily to speed. Have you specifically benchmarked it against pt.linalg.cholesky?

In the code I posted, the second model uses pt.linalg.cholesky.

pt.linalg.cholesky performs a general-purpose decomposition, applicable to any symmetric positive-definite matrix, with O(n^3) complexity.

The Bareiss algorithm is for decomposing Toeplitz matrices specifically and has O(n^2) complexity (and O(n^2) in memory). ^1
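For reference, the O(n^2) cost comes from working with a Toeplitz matrix's two "generator" vectors rather than its full Schur complement. Below is a minimal NumPy sketch of the classical Schur recursion, a close relative of the Toeplitz-specific Bareiss algorithm; it is not the fraction-free elimination used in the script above, it assumes the matrix is symmetric positive definite, and it should return the same lower factor as a dense Cholesky.

```python
import numpy as np


def schur_toeplitz_cholesky(r):
    """Lower Cholesky factor of the SPD symmetric Toeplitz matrix with first column r.

    Classical Schur (generator) recursion, O(n^2) time: T - Z T Z^T = u u^T - v v^T,
    where Z shifts a vector down by one. Each step emits one column of L, shifts u
    down, and applies a hyperbolic rotation that zeroes the next entry of v.
    """
    r = np.asarray(r, dtype=float)
    n = r.size
    u = r / np.sqrt(r[0])
    v = u.copy()
    v[0] = 0.0
    L = np.zeros((n, n))
    for k in range(n):
        L[k:, k] = u[k:]            # column k of the factor
        if k == n - 1:
            break
        u[k + 1:] = u[k:-1].copy()  # shift the first generator down by one
        u[k] = 0.0
        rho = v[k + 1] / u[k + 1]   # |rho| < 1 when T is positive definite
        s = np.sqrt(1.0 - rho * rho)
        u, v = (u - rho * v) / s, (v - rho * u) / s  # hyperbolic rotation, O(n)
    return L


# Usage: ExpQuad correlation on an evenly spaced grid, plus diagonal jitter
n_x = 6
x = np.linspace(0, 1, n_x)
r = np.exp(-0.5 * ((x - x[0]) / 0.3) ** 2)
r[0] += 1e-6
T = r[np.abs(np.arange(n_x)[:, None] - np.arange(n_x)[None, :])]
L = schur_toeplitz_cholesky(r)
```

Each of the n steps touches O(n) entries, which is where the O(n^2) total comes from.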

So we should expect the Bareiss algorithm to outperform traditional Cholesky for Toeplitz matrices, possibly with enough fixed overhead that its benefits only appear once the matrix reaches a certain size, and then grow with n_x.

In reality, though, we’re comparing a very low-level implementation of traditional Cholesky (including explicit gradients) against a very high-level implementation of Bareiss, so when decomposing a single matrix I expected a slowdown with Bareiss. It’s possible that a high enough n_x would find a crossover point, but I didn’t think it was worthwhile exploring that.

I do think it would be worthwhile adding an explicit gradient for the Bareiss ^2 to see if that helps speed it up for the single matrix case, but I was actually more interested in how the Bareiss algorithm is structured such that vectorized computation of the decomposition of multiple matrices was easy to add. And when you set n_id>1 in the code I provided, we’re testing out that multiple-matrices scenario, where there is at least a window of values for n_x & n_id where the Bareiss algorithm is substantially faster than traditional cholesky.

Footnotes:
^1 There is a less memory-consumptive “Levinson-Durbin” algorithm (used in scipy.linalg.solve_toeplitz) that is O(n^2) in complexity and O(n) in memory, but LD is known to have numerical instability whereas Bareiss is provably stable. I actually started my explorations with the LD algorithm but found that even for low values of n_x I was seeing a concerning amount of accumulated error in the final terms of the decomposition, hence switching to Bareiss.
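For comparison, `scipy.linalg.solve_toeplitz` takes only the first column (for a symmetric matrix) and solves T y = b with Levinson-Durbin; note that it returns a solution, not a Cholesky factor. A minimal check against a dense solve, assuming an ExpQuad correlation on an evenly spaced grid with the same 1e-6 jitter; comparing residuals is one simple way to watch for the accumulated error described in the footnote.

```python
import numpy as np
from scipy.linalg import solve_toeplitz, toeplitz

x = np.linspace(0, 1, 8)
r = np.exp(-0.5 * ((x - x[0]) / 0.2) ** 2)  # first column of the correlation matrix
r[0] += 1e-6                                # diagonal jitter
b = np.random.default_rng(4).standard_normal(x.size)

T = toeplitz(r)                  # dense symmetric Toeplitz matrix, for checking only
y_ld = solve_toeplitz(r, b)      # Levinson-Durbin: O(n^2) time, O(n) extra memory
y_dense = np.linalg.solve(T, b)  # general dense LU solve: O(n^3)

resid_ld = np.max(np.abs(T @ y_ld - b))
resid_dense = np.max(np.abs(T @ y_dense - b))
```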

^2 I might end up first trying the “sub-quadratic” algorithm implemented in the SuperGauss package, which has O(n log n) complexity and has the gradient already coded.


Very cool, I didn’t know about these methods – thank you for the explanation!

It wouldn’t take much work to wrap scipy.linalg.solve_toeplitz into an Op and use it for both the forward and backward passes (the sensitivity equations of solve involve another solve against Aᵀ, which for a symmetric Toeplitz A is A itself, so you can take advantage of its structure both times). For that matter, I don’t think it would be too onerous to take the code you have for Bareiss and make a COp, so that you avoid the scan and stay in C-land for the entire computation. Then you can scan (or vectorize) over your huge stack of matrices. A numba implementation would be even easier – you could write a basic Op with a Python loop for the Bareiss algorithm, then add a numba dispatch to an njit version that pytensor could use when sampling your model.
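A minimal NumPy/SciPy sketch of the forward and backward computations such an Op would need (plain functions, not a PyTensor Op; the test values are arbitrary assumptions). For x = A⁻¹ b and an upstream gradient w of the loss w·x, the adjoints are b_bar = A⁻ᵀ w and A_bar = -b_bar xᵀ; since a symmetric Toeplitz A equals its transpose, the backward pass is another `solve_toeplitz` call. Summing A_bar over each diagonal, a step added here for illustration, gives the gradient with respect to the first column r, which is checked against central finite differences.

```python
import numpy as np
from scipy.linalg import solve_toeplitz

rng = np.random.default_rng(5)
n = 6
r = np.array([4.0, 1.0, 0.5, 0.25, 0.1, 0.05])  # strictly diagonally dominant => SPD
b = rng.standard_normal(n)
w = rng.standard_normal(n)                      # upstream gradient of loss = w @ x

x = solve_toeplitz(r, b)                        # forward pass
b_bar = solve_toeplitz(r, w)                    # backward: A^{-T} w, and A^T == A here
A_bar = -np.outer(b_bar, x)                     # sensitivity w.r.t. the full matrix
idx = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :])
r_bar = np.bincount(idx.ravel(), weights=A_bar.ravel(), minlength=n)  # sum per diagonal

# Central finite-difference check of r_bar
eps = 1e-6
fd = np.array([
    (w @ solve_toeplitz(r + eps * e, b) - w @ solve_toeplitz(r - eps * e, b)) / (2 * eps)
    for e in np.eye(n)
])
```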

One other thing to think about: in pytensor we do optimizations when you have solve(A, b) and solve(A, c), rewriting them into lu_and_piv = lu_factor(A) --> lu_solve(lu_and_piv, b), lu_solve(lu_and_piv, c). In the positive definite case, this turns into cho_solve(L, b), cho_solve(L, c). Reusing the decomposition always comes up when you are computing values and gradients, so I wonder if it’s something that could also be exploited in this case. If I understand your code correctly, you’re doing a special-case Cholesky decomposition? You could probably just hand that off to cho_solve and skip the forward step. My uneducated guess is that this would end up with better gradients.
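Outside PyTensor, the same factor-once, solve-many pattern looks like this with SciPy's `cho_factor`/`cho_solve`. This illustrates the idea only; it is not PyTensor's rewrite machinery, and the random SPD test matrix is an assumption.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(6)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)  # SPD test matrix
b, c = rng.standard_normal(n), rng.standard_normal(n)

factor = cho_factor(A, lower=True)  # O(n^3) factorization, done once
x_b = cho_solve(factor, b)          # each reuse is two O(n^2) triangular solves
x_c = cho_solve(factor, c)
```

`cho_solve` only needs a `(factor, lower)` tuple, so a lower factor computed some other way (for example by a Toeplitz-specific routine) could in principle be passed as `(L, True)`.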