I’m trying to fit a simple GP model with an exponential covariance (a damped random walk model) with timescale l. However, the MAP values do not seem reasonable, and they fail to recover the input values from a simulated time series.
The code for the model is:
def fit_drw(t, y, yerr, cadence, baseline, amplitude, precision):
    import pymc3 as pm
    import numpy as np
    import matplotlib.pyplot as plt

    with pm.Model() as model:
        # damped random walk
        l = pm.Uniform("l", lower=np.sqrt(1/2), upper=np.sqrt(1e8*baseline/2))
        # 2l^2 = tau
        # l = sqrt(tau/2)
        sigma_drw = pm.Uniform("sigma_drw", lower=0.1*precision, upper=10*amplitude)
        cov = 2*sigma_drw**2 * pm.gp.cov.Exponential(1, l)
        gp_drw = pm.gp.Marginal(cov_func=cov)

        # The Gaussian process has a single component here: the DRW kernel
        gp = gp_drw

        # Since the normal noise model and the GP are conjugates, we use `Marginal` with the `.marginal_likelihood` method
        X = t[:, None]
        y_ = gp.marginal_likelihood("y", X=X, y=y, noise=yerr)
        mp = pm.find_MAP() #start={'l': 100, 'sigma_drw': amplitude})

        # Predict
        tpred = np.linspace(np.min(t), np.max(t)+400, 1000)
        Xpred = tpred[:, None]
        mu, var = gp.predict(Xpred, point=mp, diag=True)
        sd = np.sqrt(var)

    # Plot the MAP prediction with a 1-sigma band over the data
    plt.figure(figsize=(8,4))
    plt.plot(tpred, mu, "dodgerblue", lw=3)
    plt.fill_between(tpred, mu-sd, mu+sd, color="dodgerblue", alpha=0.2)
    plt.errorbar(t, y, yerr=yerr, color="k", linestyle='none', ms=3, alpha=1)
    plt.show()
    print(mp)
    return
The MAP prediction looks reasonable when plotted below. However, the MAP value of l is huge! Why?
{'l_interval__': array(12.60262633), 'sigma_drw_interval__': array(1.12011405), 'l': array(499998.31842261), 'sigma_drw': array(8.61098201)}
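As a sanity check on the output above, the transformed value can be mapped back to l by hand. This is only a minimal sketch: it assumes that the interval transform PyMC3 applies to a `pm.Uniform` prior is the usual scaled logistic, l = lower + (upper - lower) * sigmoid(l_interval__), and it uses baseline = 5000 from the simulation. The helper name `interval_backward` is made up for illustration.

```python
import numpy as np

def interval_backward(x, lower, upper):
    """Map an unconstrained value back into (lower, upper) with a scaled logistic.

    Assumption: this mirrors the transform PyMC3 applies to Uniform priors.
    """
    return lower + (upper - lower) / (1.0 + np.exp(-x))

baseline = 5000  # value used in the simulation
l_lower = np.sqrt(1 / 2)               # lower bound of the prior on l
l_upper = np.sqrt(1e8 * baseline / 2)  # upper bound of the prior on l

# 'l_interval__' as reported by find_MAP
l_map = interval_backward(12.60262633, l_lower, l_upper)

print("upper bound on l:   ", l_upper)
print("back-transformed l: ", l_map)
print("fraction of range:  ", (l_map - l_lower) / (l_upper - l_lower))
```

With these numbers the back-transformed value reproduces the reported l and sits within about 2 units of the upper bound of 500000, which suggests the optimizer ran into the edge of the Uniform prior rather than settling on an interior optimum.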
For completeness, the code to simulate the time series is:
import numpy as np
from astroML.time_series import generate_damped_RW
dt = 50
baseline = 5000
t = np.arange(0, baseline, dt)
y = generate_damped_RW(t, tau=250, xmean=20, SFinf=0.3, z=0.0)
fit_drw(t, y, 0.005*y, cadence=dt, baseline=baseline, amplitude=np.max(y)-np.min(y), precision=0.005)
The details of the generate_damped_RW function should not be a concern, other than that the input timescale is given by 2l^2 = tau.
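To make that convention concrete, here is a small sketch that converts between tau and l using 2l^2 = tau and evaluates the corresponding exponential kernel, exp(-|dt| / (2 l^2)) = exp(-|dt| / tau). The helper names are hypothetical, and the kernel form is the one assumed in this post; whether `pm.gp.cov.Exponential` uses exactly this parameterization is worth checking against the installed PyMC3 version.

```python
import numpy as np

def tau_to_l(tau):
    """Length scale l for a given tau, assuming 2 * l**2 = tau."""
    return np.sqrt(tau / 2.0)

def l_to_tau(l):
    """Timescale tau for a given l, assuming 2 * l**2 = tau."""
    return 2.0 * l**2

def drw_kernel(dt, l):
    """Unit-amplitude exponential kernel under the assumed convention."""
    return np.exp(-np.abs(dt) / (2.0 * l**2))

tau_in = 250.0           # timescale passed to generate_damped_RW
l_in = tau_to_l(tau_in)  # length scale the fit would ideally recover
l_map = 499998.31842261  # value of l reported by find_MAP

print("input l:", l_in)
print("tau implied by the MAP l:", l_to_tau(l_map))
print("correlation at dt = 50 for the input l:", drw_kernel(50.0, l_in))
print("correlation at dt = 5000 for the MAP l:", drw_kernel(5000.0, l_map))
```

Under this convention the input tau = 250 corresponds to l of about 11.2, so the MAP value of roughly 5e5 is more than four orders of magnitude away from it.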
Thanks!