Hello!

I have some 3D data (a 2D spatial surface, sampled 500 times) that I’d like to model with a Gumbel distribution. For the mean of the Gumbel distribution, I’d like to use a 2D (spatial) GP. My trouble is lining the GP prior up with the actual data. Since the data is d1 × d2 × 500 and the GP prior is only d1 × d2, I think I need to repeat each value of the prior 500 times. Here’s a minimal (not) working example:

```python
import numpy as np
import xarray as xr
import pymc as pm
import pymc.sampling_jax
import aesara.tensor as at
from scipy.stats import gumbel_r

# simulate fake data; format as dataset
data = xr.Dataset(
    data_vars={
        "Y": (
            ["x1", "x2", "sample"],
            gumbel_r.rvs(loc=8, scale=4, size=(24, 24, 500)),
        )
    },
    coords={
        "x1": np.linspace(-2, 2, 24),
        "x2": np.linspace(-2, 2, 24),
        "sample": np.arange(500),
    },
)

# convert to vector-valued data
df = data.to_dataframe().reset_index().sort_values(by=["x1", "x2"])
Y = df["Y"]
X = df[["x1", "x2"]].drop_duplicates()  # -> has length 24*24

with pm.Model(coords={"predictors": X.columns.values}) as mwe_model:
    ## mu parameter
    # mean
    beta0_mu = pm.StudentT("beta0_mu", nu=5, mu=5, sigma=2)
    coefficients_mu = pm.StudentT("coefficients_mu", nu=5, mu=0, sigma=5, dims="predictors")
    mean_mu = pm.gp.mean.Linear(coefficients_mu, intercept=beta0_mu)
    # covariance
    scale_mu = pm.HalfCauchy("scale_mu", beta=1)
    range_mu = pm.HalfCauchy("range_mu", beta=1, shape=(2,))
    cov_mu = scale_mu**2 * pm.gp.cov.Matern32(2, range_mu)
    gp_mu = pm.gp.Latent(mean_func=mean_mu, cov_func=cov_mu)
    f_mu = at.repeat(gp_mu.prior("gp_mu", X=X.values), 500, axis=0)  # this is the line I'm curious about

    ## beta parameter
    beta = pm.HalfStudentT("beta", nu=5, sigma=3)
    Y_ = pm.Gumbel("spatial_gumbel", mu=f_mu, beta=beta, observed=Y.values)

    pymc.sampling_jax.sample_numpyro_nuts(
        tune=1500,
        target_accept=0.9,
        postprocessing_backend="gpu",
        chain_method="parallel",
        # ignore log-likelihood due to memory error in calculation
        idata_kwargs={"log_likelihood": False},
    )
```
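As a sanity check on the data preparation, here is a NumPy/pandas-only sketch (no PyMC) on a scaled-down stand-in: a 3 × 3 grid with 4 samples instead of 24 × 24 × 500, with random Gumbel values in place of the real data. The `pd.MultiIndex.from_product` table is an assumed substitute for `data.to_dataframe()`; its rows are deliberately generated sample-major so that only the `sort_values` call groups them by location. The check is that repeating each row of `X` consecutively reproduces the `(x1, x2)` columns of the long-format table row by row, which is what `repeat(..., axis=0)` relies on.

```python
import numpy as np
import pandas as pd

# Scaled-down stand-in for the real 24 x 24 x 500 dataset
n1, n2, n_samples = 3, 3, 4
x1 = np.linspace(-2, 2, n1)
x2 = np.linspace(-2, 2, n2)

# Long-format table, one row per (sample, x1, x2); rows are generated
# sample-major on purpose, so only the sort below groups them by location
index = pd.MultiIndex.from_product(
    [np.arange(n_samples), x1, x2], names=["sample", "x1", "x2"]
)
rng = np.random.default_rng(0)
df = pd.DataFrame({"Y": rng.gumbel(8, 4, size=len(index))}, index=index)

# Same preparation as in the example: sort by location, one X row per location
df = df.reset_index().sort_values(by=["x1", "x2"])
X = df[["x1", "x2"]].drop_duplicates()

# repeat(..., axis=0) copies each location n_samples times in a row;
# this should match the (x1, x2) columns of the sorted long-format table
X_rep = np.repeat(X.to_numpy(), n_samples, axis=0)
print(X_rep.shape)  # (36, 2)
print(np.array_equal(X_rep, df[["x1", "x2"]].to_numpy()))  # True
```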

Running the example gives me the following error:

```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [288000]. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
```
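The 288000 in the message is exactly 24 × 24 × 500, the length of the repeated prior. The NumPy-only sketch below (with a random vector `f` standing in for a GP prior draw) shows two forms that produce the same vector as `repeat`: indexing with an integer array computed up front, and broadcasting `f[:, None]` to shape `(576, 500)` before flattening. Whether either form avoids the JAX error inside the model is an untested assumption; the broadcasting form would also presume the observed data is reshaped to `(576, 500)`, which is only valid because the rows are grouped by location as checked earlier.

```python
import numpy as np

n_locations = 24 * 24  # rows in X, i.e. the length of the GP prior
n_samples = 500        # replicates per location
print(n_locations * n_samples)  # 288000, the length in the error message

# Random stand-in for one GP prior draw: one value per location
f = np.random.default_rng(1).normal(size=n_locations)
f_rep = np.repeat(f, n_samples)

# Form 1: index with a precomputed (concrete) integer array
rep_idx = np.repeat(np.arange(n_locations), n_samples)
print(np.array_equal(f[rep_idx], f_rep))  # True

# Form 2: broadcast a column vector against a (locations, samples) layout
f_bcast = np.broadcast_to(f[:, None], (n_locations, n_samples)).ravel()
print(np.array_equal(f_bcast, f_rep))  # True
```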

I think the line causing the problem is `f_mu = at.repeat(gp_mu.prior("gp_mu", X=X.values), 500, axis=0)`.

Any suggestions on how to set this up correctly? Thanks!