Thank you very much! I looked into PyTensor, and it now works after replacing both the `prior` and `conditional` methods.
I have attached the class below, along with the sampling code, in case it is useful to anybody.
import pytensor.tensor as pt
class SparseLatent:
    """Sparse latent Gaussian process defined through a set of inducing points.

    A small replacement for a latent GP's ``prior``/``conditional`` pair,
    written directly against PyTensor. The inducing values use a
    non-centred parameterisation ``u = L @ v`` with ``v ~ N(0, I)`` and
    ``L = chol(Kuu)``. The latent function at ``X`` is ``Kfu @ Kuu^{-1} @ u``.

    ``prior`` must be called before ``conditional``. ``prior`` stores ``L``,
    ``Kuu^{-1} u`` and the inducing points on the instance, and
    ``conditional`` reuses them.
    """

    def __init__(self, cov_func):
        """Store the covariance function.

        ``cov_func`` is called as ``cov_func(A)`` or ``cov_func(A, B)``.
        """
        self.cov = cov_func

    def prior(self, name, X, Xu):
        """Add the inducing-point prior to the active model.

        Parameters
        ----------
        name : str
            Suffix for the created variables ``u_rotated_{name}``,
            ``u_{name}`` and ``mu_{name}``.
        X : array-like
            Inputs at which the latent function is evaluated.
        Xu : array-like
            Inducing-point locations; ``len(Xu)`` inducing values are created.

        Returns
        -------
        The ``mu_{name}`` Deterministic, equal to ``Kfu @ Kuu^{-1} @ u``.
        """
        self.Xu = Xu  # remembered so ``conditional`` can reuse the same points
        Kuu = self.cov(Xu)
        self.L = pt.linalg.cholesky(pm.gp.util.stabilize(Kuu))
        # Whitened inducing values: u = L v with v ~ N(0, I).
        self.v = pm.Normal(f"u_rotated_{name}", mu=0.0, sigma=1.0, shape=len(Xu))
        self.u = pm.Deterministic(f"u_{name}", pt.dot(self.L, self.v))
        Kfu = self.cov(X, Xu)
        # Kuu^{-1} u = L^{-T} (L^{-1} u), and L^{-1} u = L^{-1} L v = v,
        # so a single back-substitution against L^T is enough.
        self.Kuiu = pt.slinalg.solve_triangular(self.L.T, self.v, lower=False)
        self.mu = pm.Deterministic(f"mu_{name}", pt.dot(Kfu, self.Kuiu))
        return self.mu

    def conditional(self, name, Xnew, Xu=None):
        """Add the predictive distribution of the latent function at ``Xnew``.

        Mean is ``Ksu Kuu^{-1} u``; covariance is ``Kss - Ksu Kuu^{-1} Kus``.

        Parameters
        ----------
        name : str
            Name of the created ``MvNormal`` variable.
        Xnew : array-like
            Prediction inputs.
        Xu : array-like, optional
            Inducing points. These must be the ones passed to ``prior``.
            If omitted, the points stored by ``prior`` are used.

        Returns
        -------
        The ``MvNormal`` predictive variable.
        """
        if Xu is None:
            Xu = self.Xu
        Ksu = self.cov(Xnew, Xu)
        mus = pt.dot(Ksu, self.Kuiu)
        # Qss = Ksu Kuu^{-1} Kus is computed as tmp.T @ tmp with
        # tmp = L^{-1} Kus, which avoids forming Kuu^{-1} explicitly.
        tmp = pt.slinalg.solve_triangular(self.L, Ksu.T, lower=True)
        Qss = pt.dot(tmp.T, tmp)
        Kss = self.cov(Xnew)
        # Kss - Qss can be close to singular; PyMC's stabilize adds a small
        # diagonal jitter before the Cholesky factorisation.
        Lss = pt.linalg.cholesky(pm.gp.util.stabilize(Kss - Qss))
        return pm.MvNormal(name, mu=mus, chol=Lss, shape=len(Xnew))
# Inducing points: every second input location, starting at index 1.
# NOTE(review): inducing points come from `X`, while both priors are evaluated
# at `X_obs` -- confirm these are intended to be different arrays.
Xu = X[1::2]


def _exp_quad_kernel(ls_name, amp_name):
    """Create hyperpriors for an ExpQuad kernel and return ``(ls, amp, kernel)``.

    Adds an InverseGamma length-scale named ``ls_name`` (mu=ℓ_μ, sigma=ℓ_σ)
    and then a Gamma(alpha=2, beta=1) amplitude named ``amp_name`` to the
    active model. Builds ``amp**2 * ExpQuad(input_dim=1, ls=ls)``.
    """
    length_scale = pm.InverseGamma(ls_name, mu=ℓ_μ, sigma=ℓ_σ)
    amplitude = pm.Gamma(amp_name, alpha=2, beta=1)
    kernel = amplitude**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=length_scale)
    return length_scale, amplitude, kernel


with pm.Model() as model_hts:
    # Sparse GP for the latent mean.
    ℓ, η, cov = _exp_quad_kernel("ℓ", "η")
    μ_gp = SparseLatent(cov)
    μ_f = μ_gp.prior("μ", X_obs, Xu)

    # Sparse GP for the log of the input-dependent noise scale.
    σ_ℓ, σ_η, σ_cov = _exp_quad_kernel("σ_ℓ", "σ_η")
    lg_σ_gp = SparseLatent(σ_cov)
    lg_σ_f = lg_σ_gp.prior("lg_σ_f", X_obs, Xu)
    σ_f = pm.Deterministic("σ_f", pm.math.exp(lg_σ_f))

    lik_hts = pm.Normal("lik_hts", mu=μ_f, sigma=σ_f, observed=y_obs_)

    # NOTE(review): with a single chain, R-hat cannot be computed.
    trace_hts = pm.sample(
        target_accept=0.95,
        return_inferencedata=True,
        random_seed=SEED,
        chains=1,
    )

# Predictive variables at Xnew are added after sampling and then drawn
# from the posterior with sample_posterior_predictive.
with model_hts:
    μ_pred = μ_gp.conditional("μ_pred", Xnew, Xu)
    lg_σ_pred = lg_σ_gp.conditional("lg_σ_pred", Xnew, Xu)
    samples_hts = pm.sample_posterior_predictive(
        trace_hts, var_names=["μ_pred", "lg_σ_pred"]
    )