Possible ways to speed up spatio-temporal GP modeling

Hi All,

I’ve been working on adding more features to my PyMC model for a couple of months now. I feel that I’m making progress, especially with the advanced features available in PyMC.

I’m now modeling spatio-temporal satellite data, and the model is shown below. While the model seems to generate the expected posterior distribution, it runs very slowly.

Interestingly, without optimizing the length scale (i.e., when using a fixed value for the length scale), the JAX-enabled Gaussian process finishes in about 10 minutes. However, once I optimize the length scale parameters (for both space and time), the processing time shoots up to more than 1 day on a CPU and over 8 hours on Colab’s single T4 GPU with JAX enabled.

Given that my total number of observations is around 700, the long processing time isn’t entirely unexpected. However, the duration seems excessive, even for a JAX-based simulation. Using multiple GPUs only reduces the time to about 4 hours (i.e., half of the 8 hours) if I use 1500 samples (with 1000 warm-up samples) for each chain.

I would greatly appreciate any advice on improving the speed of the code.

Thank you.


def ST_model():
    with pm.Model() as model:

        # Truncated-normal prior on the linear mean coefficients
        # (mu_lambda, sigma_lambda, m, K, Xs, and Y are defined outside this function)
        normal_dist = pm.Normal.dist(mu=mu_lambda, sigma=sigma_lambda)
        Lambda = pm.Truncated("Lambda", normal_dist, shape=m, lower=0, upper=10)

        #===============================================================================
        # Mean function
        #===============================================================================
        mu = LinearMean(K=K, Lambda=Lambda)

        #===============================================================================
        # Kernel variance
        #===============================================================================
        var = pm.Exponential("kernel_var", 2)

        #===============================================================================
        # Kernel length
        #===============================================================================
        length_T = pm.Exponential("kernel_length_T", 1/(7/(31*NUM_MONTHS)))
        length_S = pm.Gamma("kernel_length_S", 0.5, 0.2)
        
        #===============================================================================
        # Noise
        #===============================================================================
        noise = pm.HalfCauchy("noise", 1, shape=1)
        
        #===============================================================================
        # Covariance function
        # Note we are not using kronecker product
        #===============================================================================
        # Spatial covariance
        spatial_cov = pm.gp.cov.Matern52(input_dim=3, ls=[length_S, length_S], active_dims=[0,1])
        
        # Temporal covariance function, with kernel variance
        temporal_cov = var * pm.gp.cov.ExpQuad(input_dim=3, ls=length_T, active_dims=[2])

        # Combined spatio-temporal covariance. 
        cov_func = spatial_cov * temporal_cov

        #===============================================================================
        # specify the GP:
        #===============================================================================
        gp = pm.gp.Latent(mean_func=mu, cov_func=cov_func)
        f = gp.prior("f", X=Xs)

        #===============================================================================
        # Likelihood
        #===============================================================================
        Y_ = pm.Normal("Y_", mu=f, sigma = noise, observed=Y)
        
        ################################################################################
        # Sampling
        ################################################################################
        trace = pm.sample_prior_predictive()
        trace.extend(pm.sampling_jax.sample_numpyro_nuts(
            NUM_SAMPLE, tune=TUNE,
            target_accept=0.9, random_seed=RANDOM_SEED,
            idata_kwargs={"log_likelihood": True}, chains=NUM_CHAIN))

    return trace

You could try replacing pm.gp.Latent with pm.gp.HSGP. This is an approximate method that is intended to speed up models with GP latent variables. See this talk by @bwengals for all the details. There are also good usage examples in the docs I linked.

Edit: Actually the nice examples I was thinking of were in the video, not the docs.


Thank you so much for your prompt and valuable advice. I will look into it and report if I find anything interesting.

Again, thank you so much!

Following the example, I quickly tested HSGP with:

 gp = pm.gp.HSGP(m=[35, 35, 35], c=4.0, mean_func=mu, cov_func=cov_func)
 f = gp.prior("f", X=Xs)

But, I had the following error message:

“NotImplementedError: The power spectral density of products of covariance functions is not implemented.”

I guess this is because I am multiplying two covariance functions here, but beyond that I don’t have a clue. I would appreciate any advice.

Thank you.

With the caveat that I’m not a GP expert, here are some thoughts:

  1. You could resort to a linear combination of the kernels? Perhaps model the log of your data instead of the levels? That’s the usual trick with multiplicative models. I’m not 100% sure if the logic holds in the kernel space.
  2. You could make two separate HSGP approximations, one for space and one for time. Extract their HSGP representations (called phi and sqrt_psd, see 45:56 in the video), multiply each by its own set of parameters, then multiply the results together. You’d essentially be doing GP(time) * GP(space) “by hand” (see the sketch just after this list).
  3. Similar to 2, but with fewer parameters: you could extract those basis features, multiply them together, and then define only a single set of parameters for the resulting product?
  4. Nuclear option: see if anyone has come up with a closed-form for the Matern52 * ExpQuad kernel, and use that directly. My understanding is that you’re technically not allowed to multiply GP kernels, but that it’s almost correct so people do it anyway. Since both kernel functions you’re using are exponential (and popular), it seems at least conceivable that there’s a closed form result, and thus there might be research into a more correct kernel to use?
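
Here’s what option 2 might look like in code. This is just a rough, untested sketch: X_space and X_time are placeholders for the spatial (lon/lat) and temporal columns of your inputs, row-aligned so each row is one observation, and the priors, m, and c values are arbitrary. Depending on your PyMC version, prior_linearized may expect pre-centered inputs, hence the explicit centering.

import numpy as np
import pymc as pm

# toy stand-ins for the (n_obs, 2) spatial and (n_obs, 1) temporal inputs
rng = np.random.default_rng(0)
X_space = rng.uniform(0, 10, size=(200, 2))
X_time = rng.uniform(0, 30, size=(200, 1))

with pm.Model() as model:
    ls_S = pm.Gamma("ls_S", 2, 1)
    ls_T = pm.Gamma("ls_T", 2, 1)
    eta = pm.Exponential("eta", 1)

    cov_S = pm.gp.cov.Matern52(2, ls=ls_S)
    cov_T = pm.gp.cov.ExpQuad(1, ls=ls_T)

    m_S, m_T = [25, 25], [25]
    gp_S = pm.gp.HSGP(m=m_S, c=4.0, cov_func=cov_S)
    gp_T = pm.gp.HSGP(m=m_T, c=4.0, cov_func=cov_T)

    # prior_linearized returns the basis functions (phi) and the square root
    # of the power spectral density (sqrt_psd) for each approximation
    phi_S, sqrt_psd_S = gp_S.prior_linearized(X_space - X_space.mean(axis=0))
    phi_T, sqrt_psd_T = gp_T.prior_linearized(X_time - X_time.mean(axis=0))

    # one set of coefficients per approximation
    beta_S = pm.Normal("beta_S", mu=0, sigma=1, size=np.prod(m_S))
    beta_T = pm.Normal("beta_T", mu=0, sigma=1, size=np.prod(m_T))

    f_S = phi_S @ (beta_S * sqrt_psd_S)
    f_T = phi_T @ (beta_T * sqrt_psd_T)

    # elementwise product of the two latent functions: GP(space) * GP(time) "by hand"
    f = pm.Deterministic("f", eta * f_S * f_T)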

This is all just spitballing. I hope a real expert can weigh in as well.


To add to what @jessegrabowski said, the prior on the lengthscale has a huge effect on sampling time; a jump from 10 minutes to 1 day isn’t unreasonable. But ~700 data points really should be doable IMO. The problem with the Gamma(0.5, 0.2) prior is that it puts a lot of mass at very small lengthscales, below the resolution of your data. Say your x data is on a rectangular grid, 1 “unit” apart, as in x[1] - x[0] = 1. Roughly speaking, lengthscale values smaller than 1 (though it’s not a strict boundary) are below the resolution of your data. When this happens the GP starts to turn into a model of IID Gaussian noise, which is already covered by your likelihood.
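
You can check this directly by drawing from that prior (a quick sketch, taking the grid spacing of 1 unit from the example above as the resolution):

import pymc as pm

# fraction of Gamma(0.5, 0.2) prior draws that fall below the data resolution
samples = pm.draw(pm.Gamma.dist(alpha=0.5, beta=0.2), draws=10_000)
print((samples < 1).mean())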

As a starting point, you could try an inverse gamma prior. It has the nice property of placing low mass at small values. When using the code below, set your lower value carefully; it should be at least the resolution of your x data.

import matplotlib.pyplot as plt
import pymc as pm

params_ls = pm.find_constrained_prior(
    distribution=pm.InverseGamma,
    lower=2,
    upper=10,
    init_guess={"alpha": 2, "beta": 10},
    mass=0.9,
)
d = pm.InverseGamma.dist(**params_ls)
s = pm.draw(d, 10_000)
plt.hist(s, 100);

I would double check the distribution at the end because every now and then pm.find_constrained_prior will return something that’s not quite right without erroring.

This all applies, by the way, whether you’re using an HSGP or a regular GP. Something to watch out for with HSGPs on spatial / time series models is that they’re not great for very small lengthscales relative to the overall width of your x data, or in other words for modeling very local relationships in the data. To capture those, m has to get really high and you lose the speed gains.


Thanks very much for your insightful suggestions, as always! I just finished watching the video and will try to get help there.

With respect to the multiplication of kernels, I was trying to implement the Kronecker product assuming separable time and space kernels. I could have used some of PyMC’s Kron functions, but I couldn’t find one that was applicable, or perhaps I didn’t know enough.

The problem is that not all of the satellite’s observation locations have data at a particular measurement time. In other words, depending on the measurement time, some locations have missing data. Because of this, I could not use a fixed grid of lon, lat, and time; instead, I had to make a mesh of lon, lat, and time that excludes the missing locations and times. This lon, lat, time mesh (an mn x 3 array) lets me build a (big) mn x mn matrix for both space and time, and then I just do an element-wise multiplication instead of using the Kronecker product. Here, PyMC’s active_dims parameter was a game changer for me.
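
Concretely, the construction looks roughly like this (a simplified sketch with toy values standing in for my actual arrays):

import numpy as np

# one entry per non-missing satellite observation
lon_obs = np.array([-118.20, -118.30, -118.25])
lat_obs = np.array([34.05, 34.10, 34.00])
time_obs = np.array([0.0, 1.0, 2.0])

# (n_obs, 3) input: columns 0-1 are space, column 2 is time
Xs = np.column_stack([lon_obs, lat_obs, time_obs])

# the spatial kernel acts on columns 0-1 (active_dims=[0, 1]) and the temporal
# kernel on column 2 (active_dims=[2]); multiplying the covariance objects
# multiplies the two n_obs x n_obs matrices element-wise, with no Kronecker
# structure assumed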

I think my logic is correct here, but I’d appreciate advice if I am wrong or if this is inefficient.

By the way, I saw a multiplication example from this PyMC doc (though old): Kronecker Structured Covariances — PyMC3 3.11.4 documentation.

Thank you so much!

The Kronecker stuff sounds like an option then!

The problem is that not all of the satellite’s observation locations have data at a particular measurement time. In other words, depending on the measurement time, some locations have missing data.

How many are missing? If you use LatentKron (I think this will work for MarginalKron too but I’d have to try it) you can still pass a full grid of X_spatial and X_time. Then index the GP where you actually have observed data and feed that into the likelihood and then do NUTS sampling. You’ll get the result of the GP interpolating the missing spatial/time points “for free”.

Totally orthogonal to the discussion of GPs, but if the number of raster cells in your data is smaller than the number of timesteps, you could also consider writing the model as a statespace system. I’m thinking basically a diagonal VAR(p) with spatially correlated errors. This handles missing data very naturally, and might be a more convenient way to represent things. The wrinkle is that the models we have in PyMC at the moment are just slow for very large state vectors (and yours probably would be; in your case n × p, where n is the number of rasters and p is the number of lags in the VAR).

That sounds amazing! Although I haven’t calculated it, I think more than ~90% of the points in the full space-time grid are missing for a given month. That’s because the satellite returns to the same location only every few days, and even then it only passes over a section (i.e., a swath) of the land. Then there are quality issues, so we toss out lots of data points.

If there is a simple toy example for this full grid + indexing, that would be extremely helpful.

Thank you so much!

Thanks for the suggestion! By the way, I saw the new state space (SS) model implementation. I’ve used state space models in R for a while, primarily simple random walk models. I was delighted to see that you are developing SS models within the PyMC framework.

I guess I have too many raster cells, although someone in my group is testing spatial aggregation before doing any GP or analytical multivariate normal modeling. But your idea of a diagonal VAR sounds very interesting.

I will keep paying attention to your SS models.

Thank you very much!


If there is a simple toy example for this full grid + indexing, that would be extremely helpful.

Sure. Slight modification of this example. I removed all the comments except for what’s relevant for this discussion.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm


RANDOM_SEED = 12345
rng = np.random.default_rng(RANDOM_SEED)

n1, n2 = (50, 30)
x1 = np.linspace(0, 5, n1)
x2 = np.linspace(0, 3, n2)

# inputs for the Kronecker GP: the 1-D grids, plus their full cartesian product
Xs = [x1[:, None], x2[:, None]]
X = pm.math.cartesian(*Xs)

l1_true = 0.8
l2_true = 1.0
eta_true = 1.0

cov = (
    eta_true**2
    * pm.gp.cov.Matern52(2, l1_true, active_dims=[0])
    * pm.gp.cov.Cosine(2, ls=l2_true, active_dims=[1])
)

K = cov(X).eval()
f_true = rng.multivariate_normal(np.zeros(X.shape[0]), K, 1).flatten()

sigma_true = 0.25
y_true = f_true + sigma_true * rng.standard_normal(X.shape[0])

# indices of the y data we're keeping
keep_percentage = 0.5
n_keep = int(np.round(keep_percentage * len(y_true)))
ix = np.arange(len(y_true))
observed_ix = rng.choice(ix, size=n_keep, replace=False)
y = y_true[observed_ix] # now your observed data is a selection of y_true, and you know the indices


with pm.Model() as model:
    ls1 = pm.Gamma("ls1", alpha=2, beta=2)
    ls2 = pm.Gamma("ls2", alpha=2, beta=2)
    eta = pm.HalfNormal("eta", sigma=2)

    cov_x1 = pm.gp.cov.Matern52(1, ls=ls1)
    cov_x2 = eta**2 * pm.gp.cov.Cosine(1, ls=ls2)

    sigma = pm.HalfNormal("sigma", sigma=2)

    gp = pm.gp.LatentKron(cov_funcs=[cov_x1, cov_x2])
    f = gp.prior("f", Xs=Xs)

    # This line is the only one that's different.  Index the GP where you observed data
    f_obs = f[observed_ix]
    y_ = pm.Normal("y_", mu=f_obs, sigma=sigma, observed=y)

Now your trace of f will cover the full grid of your data. It will give you the same posterior of f as if you used the GP to “predict” what the f values would be there.

If you have 700 total data points, how big is your spatial grid, and how many points in time?


Thanks very much for this piece of code! I had not realized the impact of the length scale on the speed (or maybe also on model performance). Yes, I will try this.

Thank you!

Wow, thank you so much! I am getting a lot of help today. I really appreciate your time!

I love your statement:

It will give you the same posterior of f as if you used the GP to “predict” what the f values would be there.

It is really great that we do get predictions while doing inference!

For the time dimension, we are doing hourly modeling by month for the entire year 2020. So, for a month, the full time index has a size of 24 × 30 = 720.

For the exact spatial domain boundary, I’d have to ask my postdoc, but it is roughly the Southern California area (LA County + suburban areas). The satellite resolution is ~ 2 km. My postdoc is using an analytical solution, and, in parallel, I am now into PyMC GP!

Thank you!

Hi Bill,

Following your example, I tried to use

gp = pm.gp.LatentKron(cov_funcs=[cov_x1, cov_x2])
f = gp.prior("f", Xs=Xs)

But I’m having difficulty setting up Xs.

Following Kronecker Structured Covariances — PyMC3 3.11.4 documentation, which seems to be the example you are using, I set up Xs as:

Xs = [lons[:, None], lats[:, None], dt_full_norm[:, None]]

where lons and lats are the unique longitudes and latitudes, which would make a full grid if I used np.meshgrid. For time, dt_full_norm is the full set of time stamps.

Then I applied as:

spatial_cov = pm.gp.cov.Matern52(input_dim=3, ls=[length_S, length_S], active_dims=[0,1])
temporal_cov = var*pm.gp.cov.ExpQuad(input_dim=3, ls=length_T, active_dims=[2])

gp = pm.gp.LatentKron(cov_funcs=[spatial_cov, temporal_cov])
f = gp.prior("f", Xs=Xs)

Then I have the following error message:

ValueError: Must provide a covariance function for each X

I think this happens because my Xs is a list instead of an n x 3 array. I used to use an n x 3 array, which I constructed manually by keeping only the points where observations are available. But Xs = [lons[:, None], lats[:, None], dt_full_norm[:, None]] yields a list.

I think I am missing something here. Any suggestion?

Thank you.

I think you just need to change input_dim and active_dims. Since it’s separable, each cov_func is just acting over each part of Xs.

Can you try it with this?

spatial_cov = pm.gp.cov.Matern52(input_dim=2, ls=[length_S, length_S])
temporal_cov = var*pm.gp.cov.ExpQuad(input_dim=1, ls=length_T)

EDIT: changed input_dim on temporal_cov from 2 to 1.
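
Correspondingly, Xs would need to be a two-element list matching the two covariance functions: the full spatial grid as an (n_lon * n_lat, 2) array, and the time stamps as an (n_time, 1) array. A rough, untested sketch (inside your model block, using lons, lats, dt_full_norm, length_S, length_T, and var as defined in your posts):

# spatial part: full lon/lat grid; temporal part: all time stamps
X_spatial = pm.math.cartesian(lons[:, None], lats[:, None])
Xs = [X_spatial, dt_full_norm[:, None]]

# one covariance function per element of Xs, acting on all of its columns
spatial_cov = pm.gp.cov.Matern52(input_dim=2, ls=[length_S, length_S])
temporal_cov = var * pm.gp.cov.ExpQuad(input_dim=1, ls=length_T)

gp = pm.gp.LatentKron(cov_funcs=[spatial_cov, temporal_cov])
f = gp.prior("f", Xs=Xs)
# then index f at the observed grid points, as in the earlier example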


Thank you so much! I quickly tried it, but it didn’t work. However, from your comment, I have an idea of a few different things to try. If none of them work, I will ask for your help!

Again, thanks so much!
