MarginalSparse with low sigma as a workaround for sparse Latent GP

I have a large dataset (essentially a timeseries) with >20,000 points (already after decimating; I cannot go below that), and I would like to put GP priors on some of its properties (the data are approximately iid Normal, and I want GP priors on the mean and variance). However, pymc3 currently doesn't have a sparse Latent implementation, so I have been using MarginalSparse with a low sigma as a workaround, like so:

import numpy as np
import pymc3 as pm
import theano.tensor as tt

# X: (n, 1) inputs, Xu: (m, 1) inducing points, y_fluc: zero-mean observations
with pm.Model() as model_fluc:
    T = pm.Gamma('T', alpha=3, beta=1, testval=3)  # RatQuad lengthscale
    s = pm.Gamma('s', alpha=2, beta=0.5)           # signal amplitude
    a = pm.Gamma('a', alpha=1.5, beta=1)           # RatQuad shape parameter
    cov = s**2 * pm.gp.cov.RatQuad(1, a, T)
    gp = pm.gp.MarginalSparse(cov_func=cov)
    # A small sigma makes the marginal behave approximately like a latent GP.
    var_log = gp.marginal_likelihood('var_log', X, Xu, y=None, sigma=1e-2,
                                     is_observed=False)
    var = pm.Deterministic('var', tt.exp(var_log))
    obs = pm.Normal('obs', sd=tt.sqrt(var), observed=y_fluc)  # y_fluc has 0 mean
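(For completeness, a hedged sketch of how the inducing points `Xu` above might be chosen. This is not from the original post; `n_inducing` and the uniform-subsampling strategy are illustrative choices, and k-means over `X` is a common alternative for non-uniformly sampled inputs.)

```python
import numpy as np

# Toy stand-in for the (sorted) timeseries inputs X, shape (n, 1).
n = 20000
X = np.sort(np.random.default_rng(0).uniform(0, 100, size=n))[:, None]

# Pick inducing points by taking evenly spaced rows of X.
n_inducing = 200
idx = np.linspace(0, n - 1, n_inducing).round().astype(int)
Xu = X[idx]  # (n_inducing, 1), still sorted in time
```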

My question is how valid, robust, and usable such a workaround is. I get NaN/inf errors during sampling or MAP search when I lower sigma further. The MAP estimate of var agrees reasonably well with a rolling-variance estimate, but sampling is really slow.
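(For reference, the rolling-variance baseline mentioned above can be computed with a plain numpy sliding window. The synthetic `y_fluc`, the window size `w`, and the sinusoidal variance profile below are all illustrative, not from the original post.)

```python
import numpy as np

rng = np.random.default_rng(1)
# Zero-mean data whose variance drifts slowly in time, like y_fluc.
n = 5000
true_sd = 1.0 + 0.5 * np.sin(np.linspace(0, 6 * np.pi, n))
y_fluc = rng.normal(0.0, true_sd)

# Rolling variance over a centered window of 201 points.
w = 201
windows = np.lib.stride_tricks.sliding_window_view(y_fluc, w)
rolling_var = windows.var(axis=1)  # length n - w + 1

# Mean absolute error against the true variance at the window centers.
err = np.abs(rolling_var - true_sd[w // 2 : n - w // 2] ** 2).mean()
```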

The slowness could be (partially) caused by consecutive data samples being correlated. But incorporating that would require a non-sparse Latent GP with a short-wavelength kernel, and that won't fit even into 100 GB of RAM (I tried) …


To me this is a perfectly valid thing to do. @bwengals is actually working on extending the GP functionality further this summer, which will include a sparse Latent GP.

Glad to hear I’m not completely wrong, thank you.
Any suggestions why it becomes unstable with low sigma values?

It's not surprising that it does, given how the math works out; for instance, check out this line, where sigma2 is in the denominator.

This may make it into PyMC3's codebase, but in the meantime you could adapt this gist to get your example working.


Thank you for confirming my suspicions. I suppose I cannot use .predict() anyway, though.

I tried to start implementing a LatentSparse class according to your notebook, @bwengals, but I'm a little lost with the proper use of mean_func and mean_total in .conditional(). See my first attempts here.
I think I was able to get quite close with .prior(), though. Could you please point me to some literature so that I can properly implement the mean usage?
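(For what it's worth, one plausible reading of how the two means enter the sparse conditional mean is the following numpy sketch. This is an assumption on my part, not necessarily how `LatentSparse.conditional()` should be implemented: `mean_total` is subtracted from the inducing values before projecting through Kuu, and the new points' own `mean_func` is added back. The function name and signature are hypothetical.)

```python
import numpy as np

def rbf(A, B, ell=1.0):
    return np.exp(-0.5 * (A - B.T) ** 2 / ell**2)

def sparse_conditional_mean(Xnew, Xu, u, mean_func, mean_total,
                            ell=1.0, jitter=1e-8):
    """Assumed form: mu(Xnew) + K(Xnew, Xu) Kuu^{-1} (u - mean_total(Xu))."""
    Kuu = rbf(Xu, Xu, ell) + jitter * np.eye(len(Xu))
    Ksu = rbf(Xnew, Xu, ell)
    return mean_func(Xnew) + Ksu @ np.linalg.solve(Kuu, u - mean_total(Xu))

# Sanity check: if u equals mean_total at the inducing points, the
# conditional mean should reduce to mean_func at the new points.
Xu = np.linspace(0, 5, 10)[:, None]
Xnew = np.linspace(0, 5, 50)[:, None]
m = lambda X: 2.0 * np.ones(len(X))
mu = sparse_conditional_mean(Xnew, Xu, m(Xu), mean_func=m, mean_total=m)
```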

That's awesome! Go ahead and open the PR, and we can discuss there.

For future reference, the discussion on the LatentSparse implementation will continue in this PR.
