Ah, I should have anticipated someone would be on this already! In my testing so far, I'm not seeing any speedups when decomposing and matrix-multiplying a single matrix (indeed it's much slower), but it wouldn't surprise me if that's simply because cholesky() has a hand-coded gradient. I am seeing solid speedups for the hierarchical case, scaling with n_id as expected, but the speedup diminishes and eventually reverses as n_x increases, leaving a fairly narrow window of utility (n_x of roughly 2 to 8). Still, the original problem that inspired this had n_x=5 and n_id=1e6, where this should save days of compute, so it may still be worth working out the full integration in pytensor.

My implementation is below, along with a comparison against the usual approach; what would you suggest as the first steps toward integrating this?
# exec(open("gp_as_sem_test.py").read())
use_numba = False # Set to True if you want to use Numba for compilation
n_x = 2
n_id = 100
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor.scan import scan
import os
import time
import humanize
np.set_printoptions(linewidth=np.inf)
class time_tracker():
    def __init__(self, name=None):
        self.name = name
        self.start_time = time.time()
        print(f"{self.name}...")
    def elapsed(self):
        self.duration = time.time() - self.start_time
        print(f"{self.name} duration: {humanize.precisedelta(self.duration, minimum_unit='seconds')}")
def sample(model):
    n_chains = int(os.cpu_count() / 2)
    sampling_tracker = time_tracker("Sampling")
    with model:
        posterior_trace = pm.sample(
            cores = n_chains
            , chains = n_chains
            , nuts_sampler = 'pymc'
            , random_seed = 112358
            , compile_kwargs = {'mode':'NUMBA'} if use_numba else {'mode':'FAST_RUN'}
        )
    sampling_tracker.elapsed()
    return posterior_trace
def bareiss_cholesky_toeplitz_batch_pt(r_batch, inds, eps=1e-12):
    """
    Compute Cholesky factors for many symmetric Toeplitz matrices
    using the fraction-free Bareiss algorithm.

    Parameters
    ----------
    r_batch : TensorVariable, shape (q, n)
        Each row i is the first column (and row) of the i-th Toeplitz matrix.
        Must satisfy r_batch[:, 0] > 0.
    inds : array-like, shape (n, n)
        Precomputed index-difference array abs(i - j), used to expand each
        first row into its full Toeplitz matrix via advanced indexing.
    eps : float, optional
        Tiny constant to guard against division by zero or negative sqrt
        due to numerical drift.

    Returns
    -------
    L_final : TensorVariable, shape (q, n, n)
        Lower-triangular factors so that for each i:
            L_final[i] @ L_final[i].T == toeplitz_matrix_from(r_batch[i]).
    """
    # ------------------------------------------------------------------------
    # 1) Extract the batch dimensions
    # ------------------------------------------------------------------------
    q, n = r_batch.shape  # q = number of matrices, n = matrix size
    # ------------------------------------------------------------------------
    # 2) Build the batched Toeplitz tensor M of shape (q, n, n)
    #    via advanced indexing with a constant index-difference array.
    # ------------------------------------------------------------------------
    # Broadcast r_batch over the pre-computed index-difference array: M[i] = toeplitz(r_batch[i])
    M0 = r_batch[:, inds]  # shape (q, n, n)
    # ------------------------------------------------------------------------
    # 3) Initialize the output L and the "previous pivot" vector
    # ------------------------------------------------------------------------
    # L will be filled column by column; start with zeros
    L0 = pt.zeros((q, n, n), dtype=r_batch.dtype)
    # pivot_prev starts as all ones (since the Bareiss algorithm uses 1 as the
    # previous pivot before the first step)
    pivot_prev0 = pt.ones((q,), dtype=r_batch.dtype)
    # ------------------------------------------------------------------------
    # 4) Define the scan step: one iteration of k = 0 .. n-1
    # ------------------------------------------------------------------------
    def step(k, M_prev, pivot_prev, L_prev):
        """
        Single Bareiss pass at column k for all matrices in the batch.

        Inputs:
            k          : scalar index (0 <= k < n)
            M_prev     : tensor (q, n, n), current Schur-complement blocks
            pivot_prev : vector (q,), previous pivot values
            L_prev     : tensor (q, n, n), accumulated L up to column k-1
        Outputs:
            M_next : updated Schur-complement for the next iteration
            pivot  : current pivot (q,)
            L_next : L_prev with column k filled in
        """
        # 4a) Extract the current pivot entries for each matrix
        pivot = M_prev[:, k, k]  # shape (q,)
        # 4b) Compute the denominator = sqrt(pivot * pivot_prev) with jitter
        denom = pt.sqrt(pt.clip(pivot * pivot_prev, eps, pt.inf))  # shape (q,)
        # 4c) Fill in L's diagonal entry L[:, k, k] = pivot / denom
        L1 = pt.set_subtensor(L_prev[:, k, k], pivot / denom)
        # 4d) Fill in the below-diagonal entries L[:, k+1:, k] = M[:, k+1:, k] / denom
        #     Note: if k == n-1, the slice is empty, so this is a no-op
        L1 = pt.set_subtensor(
            L1[:, k+1:, k],
            M_prev[:, k+1:, k] / denom[:, None]
        )
        # 4e) Compute the next Schur-complement block via the Bareiss update:
        #     M_ij <- (M_ij * pivot - M_ik * M_kj) / pivot_prev
        num = (
            M_prev[:, k+1:, k+1:] * pivot[:, None, None]
            - M_prev[:, k+1:, k][:, :, None] * M_prev[:, k, k+1:][:, None, :]
        )
        M1 = pt.set_subtensor(
            M_prev[:, k+1:, k+1:],
            num / pivot_prev[:, None, None]
        )
        # 4f) Zero out the eliminated column below the pivot (optional)
        # M1 = pt.set_subtensor(M1[:, k+1:, k], 0.0)
        # Return new states for the next iteration
        return M1, pivot, L1
    # ------------------------------------------------------------------------
    # 5) Execute scan over k = 0, 1, ..., n-1
    # ------------------------------------------------------------------------
    #    sequences    = the k-indices
    #    outputs_info = initial values for [M, pivot_prev, L]
    result, updates = scan(
        fn=step
        , sequences=pt.arange(n)
        , outputs_info=[M0, pivot_prev0, L0]
        , return_list=True
    )
    M_seq, piv_seq, L_seq = result
    # ------------------------------------------------------------------------
    # 6) Extract the final L from the last scan output
    #    L_seq has shape (n, q, n, n); we want the last time step
    # ------------------------------------------------------------------------
    L_final = L_seq[-1]  # shape (q, n, n)
    return L_final
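# ------------------------------------------------------------------------
# Optional (my addition, not used in the model below): a plain-NumPy mirror
# of the same recursion, handy for sanity-checking the PyTensor graph
# against np.linalg.cholesky on small examples.
# ------------------------------------------------------------------------
def bareiss_cholesky_toeplitz_batch_np(r_batch, inds, eps=1e-12):
    q, n = r_batch.shape
    M = r_batch[:, inds].copy()                   # (q, n, n) batched Toeplitz
    L = np.zeros((q, n, n), dtype=r_batch.dtype)
    pivot_prev = np.ones(q, dtype=r_batch.dtype)
    for k in range(n):
        pivot = M[:, k, k]
        denom = np.sqrt(np.clip(pivot * pivot_prev, eps, np.inf))
        L[:, k, k] = pivot / denom
        L[:, k+1:, k] = M[:, k+1:, k] / denom[:, None]
        num = (
            M[:, k+1:, k+1:] * pivot[:, None, None]
            - M[:, k+1:, k][:, :, None] * M[:, k, k+1:][:, None, :]
        )
        M[:, k+1:, k+1:] = num / pivot_prev[:, None, None]
        pivot_prev = pivot
    return L
# e.g. for some valid first rows r_np with shape (q, n):
#     np.allclose(bareiss_cholesky_toeplitz_batch_np(r_np, toeplitz_indices)[0],
#                 np.linalg.cholesky(r_np[0][toeplitz_indices]))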
def get_f_multi_pt(L, z):
    """
    Compute the batched product f = L @ z for many lower-triangular
    matrices L, using PyTensor's scan().

    Parameters
    ----------
    L : TensorVariable, shape (n_id, n_x, n_x)
        Batch of lower-triangular matrices.
    z : TensorVariable, shape (n_id, n_x)
        Batch of vectors.

    Returns
    -------
    f_final : TensorVariable, shape (n_id, n_x)
        Products f such that for each i:
            f[i] = L[i] @ z[i]
    """
    # Extract dimensions (symbolic)
    n_id, n_x = z.shape
    # 1) Compute f[:, 0] directly: the first component only involves L[:, 0, 0]
    #    (which is ~1 here, since the diagonal of the correlation matrix is 1 + jitter)
    f0 = L[:, 0, 0] * z[:, 0]  # shape (n_id,)
    # 2) Define the scan step for i = 1 .. n_x-1
    def step(i, L, z):
        """
        Compute f[:, i] = sum_{j=0..i} L[:, i, j] * z[:, j] for each batch.

        Inputs:
            i : scalar index into the second dimension
            L : (n_id, n_x, n_x) batch of matrices
            z : (n_id, n_x) batch of vectors
        Returns:
            f_i : (n_id,) the i-th component of f for each batch
        """
        # Slice L[:, i, :i+1] -> shape (n_id, i+1)
        # Slice z[:, :i+1]    -> shape (n_id, i+1)
        # Elementwise multiply and sum across the last axis
        f_i = pt.sum(L[:, i, :i+1] * z[:, :i+1], axis=1)
        return f_i  # shape (n_id,)
    # 3) Run scan over indices 1..n_x-1
    #    sequences     = the time-steps i
    #    non_sequences = L and z (they remain constant)
    #    outputs_info  = None (we only collect the returned f_i's)
    f_seq, _ = scan(
        fn=step,
        sequences=pt.arange(1, n_x),
        non_sequences=[L, z]
    )
    # f_seq has shape (n_x-1, n_id); transpose to (n_id, n_x-1)
    f_rest = f_seq.T
    # 4) Concatenate f0 and the rest to form the full f of shape (n_id, n_x)
    f_final = pt.concatenate([f0[:, None], f_rest], axis=1)
    return f_final
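# ------------------------------------------------------------------------
# Aside (my addition, not used below): since this is just a batched
# matrix-vector product, the scan above can also be written as a single
# broadcasted multiply-and-sum, which may fuse better:
# ------------------------------------------------------------------------
def get_f_multi_pt_broadcast(L, z):
    # result[i, j] = sum_k L[i, j, k] * z[i, k]  ==  (L[i] @ z[i])[j]
    return (L * z[:, None, :]).sum(axis=-1)  # shape (n_id, n_x)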
x = np.linspace(0, 1, n_x)
dists = (x[1:]-x[0])**2
neg_half_dists = np.repeat(-0.5 * dists[None,:], axis=0, repeats=n_id)
jitter = 1e-6
one_plus_jitter = 1.0 + jitter
one_plus_jitter_rep_n_id_2D = np.array([one_plus_jitter]*n_id).reshape((n_id,1))
toeplitz_indices = np.abs(np.arange(n_x)[None, :] - np.arange(n_x)[:, None])
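# For example, with n_x = 3 toeplitz_indices is
#   [[0, 1, 2],
#    [1, 0, 1],
#    [2, 1, 0]]
# i.e. |i - j|, so r[:, toeplitz_indices] assembles each (n_x, n_x) Toeplitz
# correlation matrix from its first row in a single advanced-indexing op.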
with pm.Model() as new_model:
    y_Data = pm.Data("y_Data", np.random.randn(n_id, n_x))  # Simulated observed data
    lengthscale_squared = pm.Weibull("lengthscale_squared", alpha=1.0, beta=1.0, shape=n_id)
    # r: implied correlation given lengthscale and distance grid
    r = pt.concatenate(
        [
            one_plus_jitter_rep_n_id_2D
            , pm.math.exp(
                neg_half_dists
                / pt.repeat(lengthscale_squared[:, None], axis=1, repeats=n_x-1)  # [n_id, n_x-1]
            )
        ]
        , axis = 1
    )
    L = bareiss_cholesky_toeplitz_batch_pt(r, toeplitz_indices)  # [n_id, n_x, n_x]
    z = pm.Normal("z", mu=0, sigma=1, shape=(n_id, n_x))  # [n_id, n_x]
    f = get_f_multi_pt(L, z)  # [n_id, n_x]
    y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)
# new_model.debug(verbose=True)
# sample the prior
with new_model:
    prior_trace = pm.sample_prior_predictive(
        draws = 1
        , var_names = ["y"]
        , random_seed = 112358
    )
# Set the data for the model
with new_model:
    pm.set_data({"y_Data": prior_trace.prior_predictive.y.to_numpy().squeeze()})
# sample
new_posterior_trace = sample(new_model)
# For comparison, below is the same model but using cholesky() and @
from pymc.gp.util import stabilize
def get_f_i(i_id, lengthscale, z, x):
    """
    Compute f = L @ z for a single individual, using the given lengthscale.
    """
    cov = pm.gp.cov.ExpQuad(1, ls=lengthscale[i_id]).full(x)
    L = pt.linalg.cholesky(stabilize(cov))
    f = L @ z[i_id]  # [n_x, n_x] @ [n_x] = [n_x]
    return f
def get_f_vec(lengthscale, z, x):
    """
    Compute f = L @ z for a single individual, using the given lengthscale
    (vectorized over individuals via pt.vectorize below).
    """
    cov = pm.gp.cov.ExpQuad(1, ls=lengthscale).full(x)
    L = pt.linalg.cholesky(stabilize(cov))
    f = L @ z  # [n_x, n_x] @ [n_x] = [n_x]
    return f
x_2D = x[:,None] # [n_x,1]
id_sequence = np.arange(n_id)
with pm.Model() as usual_model:
    y_Data = pm.Data("y_Data", np.random.randn(n_id, n_x))
    x_2D_Data = pm.Data("x_2D_Data", x_2D)  # [n_x, 1]
    # Weibull(2, 1) on the lengthscale is the same prior as Weibull(1, 1) on its square above
    lengthscale = pm.Weibull("lengthscale", alpha=2.0, beta=1.0, shape=n_id)
    z = pm.Normal("z", mu=0, sigma=1, shape=(n_id, n_x))
    f = pt.vectorize(get_f_vec, signature=(f'(),({n_x}),({n_x},1)->({n_x})'))(lengthscale, z, x_2D_Data)
    # f, _ = scan(
    #     fn = get_f_i
    #     , sequences = [id_sequence]
    #     , non_sequences = [lengthscale, z, x_2D]
    # )
    y = pm.Normal("y", mu=f, sigma=1, observed=y_Data)
# usual_model.debug()
# sample the prior
with usual_model:
    prior_trace = pm.sample_prior_predictive(
        draws = 1
        , var_names = ["y"]
        , random_seed = 112358
    )
# Set the data for the model
with usual_model:
    pm.set_data({"y_Data": prior_trace.prior_predictive.y.to_numpy().squeeze()})
# sample
usual_posterior_trace = sample(usual_model)
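# ------------------------------------------------------------------------
# Rough consistency check (my addition, purely as a sanity check): the two
# parameterizations should recover similar posteriors, since the priors on
# lengthscale_squared and lengthscale**2 match. arviz ships with pymc.
# ------------------------------------------------------------------------
import arviz as az
print(az.summary(new_posterior_trace, var_names=["lengthscale_squared"]).head())
print(
    (usual_posterior_trace.posterior["lengthscale"] ** 2)
    .mean(dim=("chain", "draw"))
    .to_numpy()[:5]
)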