For those who find their way here: the solution I ended up using was to construct the RVs in my model with the gp.Marginal.conditional function and to specify the shape of the x_rv RV.
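Roughly speaking, a gp.Marginal.conditional call with a given dict yields the GP predictive distribution at Xnew, conditioned on the training inputs X and outputs y with noise sigma. The NumPy-only sketch that follows is not PyMC's implementation. It assumes a zero mean function and a squared-exponential kernel as a stand-in for the unspecified cov_func, and adds a small hypothetical jitter term so the Cholesky factorization stays stable when sigma is 0.

```python
import numpy as np


def sq_exp_kernel(a, b, lengthscale=1.0):
    """Squared-exponential kernel matrix between inputs a (n, d) and b (m, d)."""
    sqdist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sqdist / lengthscale**2)


def gp_conditional(X, y, Xnew, sigma=0.0, lengthscale=1.0, jitter=1e-8):
    """Predictive mean and covariance of a zero-mean GP at Xnew given (X, y)."""
    # Training covariance plus observation noise (and jitter for stability)
    K = sq_exp_kernel(X, X, lengthscale) + (sigma**2 + jitter) * np.eye(len(X))
    Ks = sq_exp_kernel(X, Xnew, lengthscale)      # (n_train, n_new)
    Kss = sq_exp_kernel(Xnew, Xnew, lengthscale)  # (n_new, n_new)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    v = np.linalg.solve(L, Ks)                            # L^{-1} Ks
    mean = Ks.T @ alpha
    cov = Kss - v.T @ v
    return mean, cov


# Hypothetical 1-D design points and training outputs
X_train = np.linspace(0.0, 1.0, 5)[:, None]
y_train = np.sin(2 * np.pi * X_train[:, 0])

# Predict back at the design points with sigma=0
mean, cov = gp_conditional(X_train, y_train, X_train, sigma=0.0, lengthscale=0.2)
```

With sigma=0 the predictive mean reproduces the training outputs and the predictive variance is (up to jitter) zero at the design points, so the emulator passes through its training data and only expresses uncertainty between them.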
In the end, it looks something like this:
import numpy as np
import pymc as pm

inference_parameters = [...]  # List of parameter names (strings)
with pm.Model() as inference_model:
    # Uniform priors on the inferred parameters, one row per parameter
    inference_vars = pm.Uniform(
        'inference_vars',
        lower=inference_parameters_ranges[:, 0],
        upper=inference_parameters_ranges[:, 1],
        shape=(len(inference_parameters), 1)
    )
    for i, tau in enumerate(observation_times):
        # One MvNormal component per hydro model; its mean vector is built from
        # GP emulators conditioned on the training data (sigma=0: no noise)
        comp_dists = [
            pm.MvNormal.dist(
                mu=[
                    pm.gp.Marginal(cov_func=cov_func).conditional(
                        name=f'{name}_{observable}_{i}',
                        Xnew=inference_vars,
                        given={
                            'X': emulator_design_points,
                            'y': emulator_training_data[name][i, :, j + 1],
                            'sigma': 0
                        },
                        shape=(len(inference_parameters),)
                    )
                    for j, observable in enumerate([...])
                ],
                cov=np.diag(observation_error[i]),
            )
            for name in hydro_names
        ]
        # Dirichlet mixture weights with lognormal concentration parameters
        alpha = pm.Lognormal(
            f'alpha_{i}',
            mu=0.0,
            sigma=1.0,
            shape=len(hydro_names)
        )
        weights = pm.Dirichlet(f'Dirichlet_{i}', a=alpha)
        # Likelihood: mixture over hydro models for this observation time
        pm.Mixture(
            f'mix_{i}',
            w=weights,
            comp_dists=comp_dists,
            observed=observation_data[i, 1:].reshape(-1, 1),
        )
The problem with this approach is that the model takes a very long time to build once there are more than 4 observation times. I will open another question about how to pass the observation data.
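In the model, each pm.Mixture term combines the per-hydro-model MvNormal components with the Dirichlet weights. Conceptually, for one observation vector y this is the standard mixture density log p(y) = log sum_k w_k N(y | mu_k, Sigma). The NumPy sketch that follows evaluates that log-likelihood, assuming diagonal covariances as in np.diag(observation_error[i]) with the entries read as variances. It ignores PyMC's shape and broadcasting rules, so it illustrates the math rather than reproducing pm.Mixture exactly; all inputs are hypothetical.

```python
import numpy as np


def diag_mvnormal_logpdf(y, mu, var):
    """log N(y | mu, diag(var)), where var holds the diagonal of the covariance."""
    resid = y - mu
    return -0.5 * np.sum(np.log(2 * np.pi * var) + resid**2 / var)


def mixture_logp(y, weights, mus, var):
    """log sum_k w_k N(y | mu_k, diag(var)), evaluated stably via log-sum-exp."""
    comp = np.array([
        np.log(w) + diag_mvnormal_logpdf(y, mu, var)
        for w, mu in zip(weights, mus)
    ])
    m = comp.max()
    return m + np.log(np.sum(np.exp(comp - m)))


# Hypothetical observation, two component means, and diagonal error variances
y_obs = np.array([0.0, 1.0])
mus = [np.array([0.0, 1.0]), np.array([2.0, -1.0])]
var = np.array([1.0, 1.0])
logp = mixture_logp(y_obs, [0.3, 0.7], mus, var)
```

Subtracting the largest component term before exponentiating (log-sum-exp) keeps the sum from underflowing when some components fit the data very poorly.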