Hi All,

I’ve been extending my PyMC model for a couple of months now, and I feel I’m making progress, especially with the advanced features PyMC offers.

I’m now modeling spatio-temporal satellite data, and the model is shown below. While the model seems to generate the expected posterior distribution, it runs very slowly.

Interestingly, when I fix the length scales (i.e., treat them as constants rather than free parameters), the JAX-enabled Gaussian process finishes in about 10 minutes. However, once I infer the length-scale parameters (for both space and time), the run time shoots up to more than a day on a CPU and over 8 hours on Colab’s single T4 GPU with JAX enabled.

Given that I have around 700 observations in total, a long run time isn’t entirely unexpected. Still, the duration seems excessive, even for a JAX-based simulation. Using multiple GPUs only cuts the time to about 4 hours (i.e., half of the 8 hours) when I draw 1500 samples (plus 1000 warm-up samples) per chain.

I would greatly appreciate any advice on improving the speed of the code.

Thank you.

```python
def ST_model():
    # NOTE: K, mu_lambda, sigma_lambda, m, Xs, Y, NUM_MONTHS, and the sampler
    # constants (NUM_SAMPLE, TUNE, NUM_CHAIN, RANDOM_SEED) are defined globally;
    # LinearMean is a custom mean-function class.
    with pm.Model() as model:
        normal_dist = pm.Normal.dist(mu=mu_lambda, sigma=sigma_lambda)
        Lambda = pm.Truncated("Lambda", normal_dist, lower=0, upper=10, shape=m)
        #=======================================================================
        # Mean function
        #=======================================================================
        mu = LinearMean(K=K, Lambda=Lambda)
        #=======================================================================
        # Kernel variance
        #=======================================================================
        var = pm.Exponential("kernel_var", 2)
        #=======================================================================
        # Kernel length scales
        #=======================================================================
        length_T = pm.Exponential("kernel_length_T", 1 / (7 / (31 * NUM_MONTHS)))
        length_S = pm.Gamma("kernel_length_S", 0.5, 0.2)
        #=======================================================================
        # Noise
        #=======================================================================
        noise = pm.HalfCauchy("noise", 1, shape=1)
        #=======================================================================
        # Covariance function
        # Note: we are not using a Kronecker product
        #=======================================================================
        # Spatial covariance
        spatial_cov = pm.gp.cov.Matern52(
            input_dim=3, ls=[length_S, length_S], active_dims=[0, 1]
        )
        # Temporal covariance, scaled by the kernel variance
        temporal_cov = var * pm.gp.cov.ExpQuad(
            input_dim=3, ls=length_T, active_dims=[2]
        )
        # Combined spatio-temporal covariance
        cov_func = spatial_cov * temporal_cov
        #=======================================================================
        # Specify the GP
        #=======================================================================
        gp = pm.gp.Latent(mean_func=mu, cov_func=cov_func)
        f = gp.prior("f", X=Xs)
        #=======================================================================
        # Likelihood
        #=======================================================================
        Y_ = pm.Normal("Y_", mu=f, sigma=noise, observed=Y)
        ########################################################################
        # Sampling
        ########################################################################
        trace = pm.sample_prior_predictive()
        trace.extend(pm.sampling_jax.sample_numpyro_nuts(
            NUM_SAMPLE, tune=TUNE,
            target_accept=0.9, random_seed=RANDOM_SEED,
            idata_kwargs={"log_likelihood": True}, chains=NUM_CHAIN))
```