# Struggling with Latent GP in an inverse problem

Here is my model for some kind of inverse Laplace transform:

omega = np.linspace(0.5, 3.0, 100)
Kernel = np.exp(-omega[None,:] * X1[:,None])
parity = np.arange(Kernel.shape[1])%2
coefs = 4*parity + 2*(1-parity)
coefs[0] = 1
coefs[-1] = 1
SimpsonKernel = Kernel*coefs[None,:]/3.0

InversePos1_model = pm.Model()
with InversePos1_model:
    scale = 5
    ell1 = pm.Gamma("ell1", alpha=0.5, beta=4)
    eta1 = pm.LogNormal("eta1", mu=scale, sigma=1e-2*scale)
    cov_func1 = eta1**2 * pm.gp.cov.Matern52(1, ell1)

    gp = pm.gp.Latent(cov_func=cov_func1)
    logf = gp.prior("logf", X=omega[:,None])
    f = pm.Deterministic('f',pt.exp(logf))
    g = pt.dot(SimpsonKernel, f)

    Y_obs = pm.Normal("Y_obs", mu=g, sigma=stdY, observed=Yplateau)
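For reference, here is a minimal sketch of the composite Simpson 1/3 weights that `SimpsonKernel` is built from. The helper name `simpson_weights` is hypothetical. Two points are assumptions that go beyond the snippet above: the weights include the grid spacing `h`, and the rule is only properly defined for an odd number of nodes (an even number of intervals), so the sketch uses 101 points rather than 100.

```python
import numpy as np

def simpson_weights(n_points, h):
    """Composite Simpson 1/3 weights for equally spaced nodes.

    Hypothetical helper: requires an odd number of nodes (>= 3).
    """
    if n_points < 3 or n_points % 2 == 0:
        raise ValueError("composite Simpson 1/3 needs an odd number of nodes (>= 3)")
    parity = np.arange(n_points) % 2
    w = 4.0 * parity + 2.0 * (1 - parity)   # interior pattern 4, 2, 4, ...
    w[0] = 1.0                              # end points carry weight 1
    w[-1] = 1.0
    return w * h / 3.0

# Assumed grid: 101 nodes so that the number of intervals is even.
omega = np.linspace(0.5, 3.0, 101)
w = simpson_weights(omega.size, omega[1] - omega[0])

# Sanity check on a known integrand: integral of omega^2 over [0.5, 3.0].
integral = np.sum(w * omega**2)
exact = (3.0**3 - 0.5**3) / 3.0
print(integral, exact)
```

With an even number of nodes the alternating 4/2 pattern no longer matches Simpson's rule at the last interval. Dropping the spacing `h`, as in the model above, only rescales `f` by a constant.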


And my data looks essentially like that, dominated by a few exponentials:

Yplateau = Y[:,tmin:tmax+1]
guess = np.zeros_like(omega)
guess[int((0.8-omega[0])/(omega[1]-omega[0]))] = 7e1
guess[int((1.3-omega[0])/(omega[1]-omega[0]))] = 1e3
guess[int((2.1-omega[0])/(omega[1]-omega[0]))] = 2e3
plt.plot(Y,label='data')
plt.plot(np.sum(guess*Kernel,axis=1),label='guess')
plt.yscale('log')
plt.legend()
plt.show()

I can get a vaguely decent MAP, although for some reason it sometimes crashes in a Cholesky step:

with InversePos1_model:
    InversePos1_MAP = pm.find_MAP(tol=1e-15)
plt.plot(omega,InversePos1_MAP['f'])

But I could not get ADVI to work. If I allow my cov_func to be a sum over two different smoothing radii it gives something, but it diverges at a large number of iterations, producing a huge peak at the far right of f.

I couldn’t get much out of NUTS either (which is what I would really like to use). Not only is it very slow (dominated by Cholesky steps, which scale roughly as n^3 with the size of \omega), which I could accept, but it never thermalises: ergodicity is visibly broken and autocorrelations are huge, so the rhat values are terrible.
I tried many combinations of arguments and many home-made ways to initialise NUTS, but here’s one example:

Any suggestion?
Do you think I’m building a particularly sick model?

A simpler model with f\sim Gamma instead of a GP (i.e. some kind of white noise, no smoothing) was giving some results (one sharp peak on the left and then a broad peak on the right, probably mixing several smaller peaks within error bars), but it had huge error bars.
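On the occasional Cholesky crash: a Matern52 covariance evaluated on a dense grid can be badly conditioned, and adding a small diagonal "jitter" is the usual remedy. The sketch below is plain NumPy, not PyMC internals. The helper `matern52`, the length scale `ell=1.0` and the jitter value `1e-6` are illustrative assumptions.

```python
import numpy as np

def matern52(x, ell, eta=1.0):
    """Matern 5/2 covariance matrix for 1D inputs x (illustrative helper)."""
    r = np.abs(x[:, None] - x[None, :]) / ell
    s = np.sqrt(5.0) * r
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

omega = np.linspace(0.5, 3.0, 100)
K = matern52(omega, ell=1.0)

jitter = 1e-6                        # assumed value, tune to the problem
K_jit = K + jitter * np.eye(omega.size)

cond_raw = np.linalg.cond(K)         # very large: K is close to singular
cond_jit = np.linalg.cond(K_jit)     # bounded by roughly lambda_max / jitter
L = np.linalg.cholesky(K_jit)        # factorisation of the regularised matrix

print(f"condition number without jitter: {cond_raw:.3e}")
print(f"condition number with jitter:    {cond_jit:.3e}")
```

Recent PyMC versions expose a `jitter` argument on `gp.prior` for the same purpose; it is worth checking the docstring of the installed version before relying on it.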

Is it possible for you to post some runnable code? Am missing at least X1. Would be able to give some better advice.

But in your data / guess plot, I assume you’re basically using a GP to fit that data? That data will be tricky because it’s so smooth. You’ll have to set your priors on eta and ell very carefully. There is a huge range of ell values, basically out to infinity that will give you a smooth almost linear line so it’ll really help the sampler if you can constrain that so it doesn’t have to explore that whole space. That’ll help you use NUTS.
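A quick way to see how much of an `ell` prior sits in regions the sampler cannot use is to compare it with the grid itself. The sketch below does this for the Gamma(alpha=0.5, beta=4) prior from the first post. The two thresholds (one grid spacing and the full width of the omega range) are assumptions chosen for illustration, not a rule from the PyMC docs.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
omega = np.linspace(0.5, 3.0, 100)
spacing = omega[1] - omega[0]
width = omega[-1] - omega[0]

# PyMC's Gamma(alpha, beta) uses a rate beta; NumPy expects scale = 1 / beta.
alpha, beta = 0.5, 4.0
ell = rng.gamma(shape=alpha, scale=1.0 / beta, size=200_000)

frac_below_grid = np.mean(ell < spacing)    # shorter than the grid can resolve
frac_above_width = np.mean(ell > width)     # logf nearly constant over the domain
frac_useful = 1.0 - frac_below_grid - frac_above_width

# Closed-form cross-check, valid only for alpha = 0.5.
exact_below = math.erf(math.sqrt(beta * spacing))

print(f"ell < grid spacing : {frac_below_grid:.3f} (closed form {exact_below:.3f})")
print(f"ell > domain width : {frac_above_width:.3f}")
print(f"in between         : {frac_useful:.3f}")
```

If a large share of the prior mass lies below the grid spacing, the GP behaves close to white noise on this grid; shifting the mass into the "in between" range is one way to act on the advice above.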

Another thing that’ll help is that since you’re using gp.Latent with a 1D stationary covariance, Matern52, it’s possible for you to use pm.gp.HSGP approximation. You’ll have to check the docstring of it for advice setting m and c, or specifically the text around Fig.6 here.
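To make the HSGP idea concrete, here is a minimal NumPy sketch of the basis-function approximation behind it for a 1D Matern52 kernel: the covariance is approximated by `m` sine eigenfunctions on an enlarged interval `[-L, L]`, weighted by the kernel's spectral density. This is not PyMC's implementation. The function names are hypothetical, and `ell=1.0`, `m=25`, `c=4` are illustrative values.

```python
import numpy as np

def matern52(x, ell, eta=1.0):
    """Exact Matern 5/2 covariance matrix for 1D inputs."""
    r = np.abs(x[:, None] - x[None, :]) / ell
    s = np.sqrt(5.0) * r
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

def matern52_psd(w, ell, eta=1.0):
    """1D power spectral density of the Matern 5/2 kernel."""
    return eta**2 * (16.0 / 3.0) * 5.0**2.5 / ell**5 * (5.0 / ell**2 + w**2) ** (-3)

def hsgp_basis(x, m, c):
    """Sine eigenfunctions on [-L, L], with L = c * half-width of the centred x."""
    xc = x - 0.5 * (x.min() + x.max())
    L = c * 0.5 * (x.max() - x.min())
    j = np.arange(1, m + 1)
    sqrt_lam = j * np.pi / (2.0 * L)          # square roots of the eigenvalues
    phi = np.sqrt(1.0 / L) * np.sin(sqrt_lam[None, :] * (xc[:, None] + L))
    return phi, sqrt_lam

omega = np.linspace(0.5, 3.0, 50)
ell, m, c = 1.0, 25, 4.0                      # assumed values for illustration
phi, sqrt_lam = hsgp_basis(omega, m, c)

# Low-rank approximation: K ~= Phi diag(S(sqrt(lambda_j))) Phi^T
K_approx = (phi * matern52_psd(sqrt_lam, ell)[None, :]) @ phi.T
K_exact = matern52(omega, ell)
max_err = np.max(np.abs(K_approx - K_exact))
print(f"max abs error with m={m}, c={c}: {max_err:.2e}")
```

The cost per gradient evaluation scales with `n * m` instead of `n^3`, which is why it helps on finer omega grids. The approximation degrades when `ell` is small compared with `L / m`, so short length scales need a larger `m`.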

I finally had time to go back to this problem. I prepared a self-contained version with synthetic data so that I can share a full code. Here it is:

import sys
print(sys.version)
import platform
import numpy as np
from scipy import linalg
from scipy import ndimage
from scipy.stats import norm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pytensor.tensor import slinalg
from pymc.sampling.jax import sample_numpyro_nuts
from pymc.sampling.jax import sample_blackjax_nuts
print(f"Running on PyMC v{pm.__version__}")


3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03)
[GCC 11.3.0]
Running on PyMC v5.6.0

# Parameters to generate synthetic data
tsrc=31
(ncfg, ncorr, tvals) = (1000, 1, 96)
tmin = tsrc+1
tmax = tvals-20
omega = np.linspace(0.5, 3.0, 20) # 20 is extremely coarse, the larger the better

platlength = tmax - tmin +1
X1all = np.arange(-tmin,tvals-tmin)
X1 = np.arange(0,tmax-tmin+1)
X = X1[:,None]
KernelAll = np.exp(-omega[None,:] * X1all[:,None])
Kernel = np.exp(-omega[None,:] * X1[:,None])
energytruth = np.zeros_like(omega)
true_ground = omega[int((0.8-omega[0])/(omega[1]-omega[0]))]
print('True ground state: ',true_ground)
energytruth[int((0.8-omega[0])/(omega[1]-omega[0]))] = 7e1
energytruth[int((1.3-omega[0])/(omega[1]-omega[0]))] = 1e3
energytruth[int((2.1-omega[0])/(omega[1]-omega[0]))] = 2e3
timetruth = np.sum(energytruth*KernelAll,axis=1)
data = timetruth*norm.rvs(loc=1.0,scale=1e-1,size=(ncfg, ncorr, tvals))
chan=0
Y = data[:,2*chan,:]
Yplateau = Y[:,tmin:tmax+1]
Ymean = Yplateau.mean(axis=0)
stdY = Yplateau.std(axis=0)

fig, ax = plt.subplots(1,2,figsize=(12,4))
ax[0].plot(timetruth[tmin:tmax+1],label='truth')
ax[0].scatter(np.arange(0,tmax+1-tmin),Ymean,label='synthetic data',c='r')
ax[0].set_yscale('log')
ax[0].legend()
ax[1].scatter(np.arange(0,tmax+1-tmin),Ymean/timetruth[tmin:tmax+1],label='noise')
ax[1].legend()
plt.show()


White_model = pm.Model()
with White_model:
    eta = pm.LogNormal('eta', mu=1.0, sigma=3.0)
    dw = omega[1]-omega[0]
    default = np.ones_like(omega)
    f = pm.Gamma('f', alpha=eta*dw+1, beta=eta*dw/default, shape=omega.shape)

    g = pm.Deterministic('g', pt.dot(Kernel, f) )

    Y_obs = pm.Normal("Y_obs", mu=g, sigma=stdY, observed=Yplateau)

with White_model:
    White_MAP = pm.find_MAP(tol=1e-15)
plt.plot(omega,White_MAP['f'])
plt.show()

with White_model:
    White_trace = pm.sample(model=White_model,
                            target_accept=0.9,
                            #nuts_sampler='numpyro',
                            draws=1000,tune=1000,chains=4,
                            initvals=White_MAP
                            )
display(az.summary(White_trace,round_to=6))
az.plot_trace(White_trace)
plt.show()


stack = White_trace.posterior.stack(sample=("chain", "draw"))
f_post = stack['f']
f_mean = f_post.mean(dim=("sample")).values
f_std = f_post.std(dim=("sample")).values
f_minus = f_mean - f_std
f_plus = f_mean + f_std
hdi = az.hdi(White_trace,hdi_prob=.68)
f_lower=hdi['f'].sel(hdi='lower').values
f_higher=hdi['f'].sel(hdi='higher').values

fig, ax = plt.subplots(dpi=200)
sns.lineplot(x=omega, y=f_mean, color='r', label='central', ax=ax)
ax.fill_between(
x=omega,
y1=f_lower,
y2=f_higher,
color='r',
alpha=0.2,
label='68% HDI'
)
plt.legend()
plt.show()


# Isolate manually the first peak
f_cut = f_post.values[omega<1.0]

# Now find the exact position of the peak
MDtracepos = np.argmax(f_cut,axis=0)
MDtrace = omega[MDtracepos]
MD_mean = MDtrace.mean()
MD_std = MDtrace.std()
print(MD_mean, '+/-', MD_std)
MD_low,MD_high = az.hdi(MDtrace,hdi_prob=0.68)
print(MD_low,MD_high)


# Use Composite Simpson's 1/3 rule, instead of the rectangle rule
parity = np.arange(Kernel.shape[1])%2
print(parity)
coefs = 4*parity + 2*(1-parity)
coefs[0] = 1
coefs[-1] = 1
SimpsonKernel = Kernel*coefs[None,:]/3.0

InversePos_model = pm.Model()
with InversePos_model:
    scale = np.log(np.abs(np.log(400)))
    ell = pm.Gamma("ell", alpha=0.6, beta=12)
    eta = pm.LogNormal("eta", mu=scale, sigma=1e-2*scale)

    cov_func = eta**2 * pm.gp.cov.Matern52(1, ell)
    gp = pm.gp.Latent(cov_func=cov_func,mean_func=pm.gp.mean.Constant(1.0e-6))
    logf = gp.prior("logf", X=omega[:,None])
    f = pm.Deterministic("f", pt.exp(logf) )
    # This potential leads to zero acceptance in Numpyro
    #pm.Potential("positivity",pm.math.log(pm.math.switch(f>pt.zeros_like(f), 1, 0)))

    g = pt.dot(SimpsonKernel, f)

    Y_obs = pm.Normal("Y_obs", mu=g, sigma=stdY, observed=Yplateau)

# Works only in small dimension
with InversePos_model:
    InversePos_MAP = pm.find_MAP(tol=1e-15)
plt.plot(omega,InversePos_MAP['f'])
plt.show()

with InversePos_model:
    InversePos_trace = pm.sample(model=InversePos_model,
                                 target_accept=0.99,
                                 tune=500,draws=1000,chains=4,
                                 initvals=InversePos_MAP
                                 )
display(az.summary(InversePos_trace,round_to=6))
display(az.summary(InversePos_trace,round_to=6,var_names=["ell","eta"]))
print("Initial (MAP): ",InversePos_MAP['ell'],InversePos_MAP['eta'])
az.plot_trace(InversePos_trace,var_names=["ell","eta"])
plt.show()


stack = InversePos_trace.posterior.stack(sample=("chain", "draw"))
f_post = stack['f']
f_mean = f_post.mean(dim=("sample")).values
f_std = f_post.std(dim=("sample")).values
f_minus = f_mean - f_std
f_plus = f_mean + f_std
hdi = az.hdi(InversePos_trace,hdi_prob=.68)
f_lower=hdi['f'].sel(hdi='lower').values
f_higher=hdi['f'].sel(hdi='higher').values

fig, ax = plt.subplots(dpi=200)
sns.lineplot(x=omega, y=f_mean, color='r', label='central', ax=ax)
ax.fill_between(
x=omega,
y1=f_lower,
y2=f_higher,
color='r',
alpha=0.2,
label='68% HDI'
)
plt.legend()
plt.show()


So with this fake data (in which the noise is small and constant) the results are actually quite good. The last two peaks are a bit too high, since the peaks at [0.8,1.3,2.1] should have heights [70,1000,2000], and maybe my priors were too tight, but the order of magnitude is correct. I hadn’t tried synthetic data before; I should have. So the issue is how the model behaves with real-world noisy data.

I will try the HSGP as you suggested. I don’t remember whether I already did.

But in your data / guess plot, I assume you’re basically using a GP to fit that data? That data will be tricky because it’s so smooth. You’ll have to set your priors on eta and ell very carefully. There is a huge range of ell values, basically out to infinity that will give you a smooth almost linear line

Not quite. I apply the GP in the energy space (where I want my result), but that plot was in the time space (where I have my data). Those are conjugate variables, and you use the kernel to get from one to the other. In the time space the data is dominated by the ground-state exponential (the other two have sizeable contributions that are hidden by the logarithmic scale), so the plot becomes a straight line at asymptotically large times (at short times you see a bit of curvature). In the energy space we have the other extreme, which could be difficult for another reason: the function represented as a GP is a sum of a few Dirac distributions. I hope this explanation plus the code above clarifies things.
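The relation between the two spaces can be illustrated with a few lines of NumPy: three delta-like peaks in energy space are pushed through the kernel, and the log-slope of the resulting time series flattens onto the lowest energy at late times. The peak positions and heights are the ones from the synthetic setup above; the time range and the "effective energy" diagnostic `m_eff` are assumptions added for illustration.

```python
import numpy as np

omega = np.linspace(0.5, 3.0, 20)
dw = omega[1] - omega[0]
t = np.arange(0, 45)                        # assumed time range

# Energy space: a few isolated peaks on the omega grid.
f = np.zeros_like(omega)
for energy, height in [(0.8, 7e1), (1.3, 1e3), (2.1, 2e3)]:
    f[int((energy - omega[0]) / dw)] = height

# Time space: g(t) = sum_k f(omega_k) * exp(-omega_k * t)
kernel = np.exp(-omega[None, :] * t[:, None])   # shape (time, energy)
g = kernel @ f

# Local log-slope of g; it tends to the lowest occupied energy at large t.
m_eff = np.log(g[:-1] / g[1:])
ground = omega[np.nonzero(f)[0][0]]

print("log-slope at t=0 :", m_eff[0])
print("log-slope at late t:", m_eff[-1])
print("ground state on grid:", ground)
```

The early-time curvature carries all the information about the two higher peaks, which is why the inversion is so sensitive to noise there.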

Then, if I use 50 points in omega instead of 20, the “white noise” model fits in 12 minutes to give this decent result:

while the GP model is taking a long time:

But the HSGP(m=25,c=4) has interesting results in 11 minutes, so it’s definitely worth exploring further. HSGP could be the solution to my problem: