How to use a TensorVariable in GP predict method

Thank you, Jesse and Daniel (you can call me Kevin), for the help so far.

Before switching to PyMC, I was using ptemcee for something similar. There, the emulators were trained (with sklearn's GaussianProcessRegressor) on training data that was independent of the inference data. This is one flaw in the code I shared above: the training data for the GPs should not be the same as the inference data. The point of the emulators is to make model evaluations cheap.
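For reference, a minimal sketch of that kind of offline setup (not my original ptemcee code) might look like the block below. Each emulator is fit once on its own design points, so later evaluations only pay for a cheap prediction. The three toy functions standing in for the physics models, the 25-point training grid, and the small `alpha` jitter are hypothetical; the fixed length scale of 10 just mirrors the Matern-3/2 kernel used later in this post.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern


def expensive_models(x):
    """Hypothetical stand-ins for K expensive models; returns shape (K, len(x))."""
    return np.stack([np.sin(x), np.cos(x), 0.1 * x**2])


K = 3
# Training design for the emulators, deliberately separate from any observation data
emulator_training_x = np.linspace(-5.0, 5.0, 25)[:, None]           # shape (M, 1)
emulator_training_y = expensive_models(emulator_training_x[:, 0])   # shape (K, M)

emulators = []
for k in range(K):
    gp = GaussianProcessRegressor(
        # nu=1.5 is the Matern-3/2 kernel; the fixed length scale skips hyperparameter fitting
        kernel=Matern(length_scale=10.0, length_scale_bounds="fixed", nu=1.5),
        alpha=1e-8,  # small diagonal jitter for numerical stability
    )
    gp.fit(emulator_training_x, emulator_training_y[k])
    emulators.append(gp)

# Cheap evaluation of emulator 0 at a new input: predictive mean and standard deviation
x_query = np.array([[0.3]])
mu, std = emulators[0].predict(x_query, return_std=True)
```

The expensive models are only called to build the training set; everything after `fit` touches the emulators alone.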

There is another nuance: I want the shape parameters of the Dirichlet distribution to depend on another variable $\theta$ that can be measured in an experiment. In addition to learning the value $\mathbf x_\text{true}$, I want to learn how the weights depend on the measurable parameter $\theta$ (which need not appear in the vector $\mathbf x$).

So the code might look something like this:

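import pymc as pm

K = ... # Number of models
N = ... # Number of points for theta

emulator_training_x = ... # GP inputs, 2-D with shape (M, 1)
emulator_training_y = ... # assumed shape (K, M): one row of training outputs per model
observation_y = ... # has shape (N,)

with pm.Model() as m0:
    cov_func = pm.gp.cov.Matern32(1, ls=[10])
    emulators = []
    for k in range(K):
        gp = pm.gp.Marginal(cov_func=cov_func)
        # marginal_likelihood adds the training data to the model as an observed variable;
        # keep the gp object itself, because .predict() is a method of gp, not of the returned variable
        gp.marginal_likelihood(
            f'model_emulator_{k}',
            X=emulator_training_x,
            y=emulator_training_y[k],
            sigma=0
        )
        emulators.append(gp)

    x = pm.Uniform(...)
    for n in range(N):
        comp_dists = []
        for k in range(K):
            # The open question: Marginal.predict() evaluates the conditional at a single
            # concrete point and returns NumPy arrays, so the result does not stay symbolic in x
            mu, var = emulators[k].predict(Xnew=x, diag=True)
            # pm.Normal is parameterized by a standard deviation, not a variance
            comp_dists.append(pm.Normal.dist(mu=mu, sigma=var**0.5))
        shape_parameters = pm.LogNormal(
            f'shape_parameters_{n}',
            mu=0,
            sigma=1,
            shape=K
        )
        weights = pm.Dirichlet(f'weight_{n}', a=shape_parameters)
        # observed goes on the Mixture; the unnamed .dist() components cannot take it
        pm.Mixture(f'mixture_{n}', w=weights, comp_dists=comp_dists, observed=observation_y[n])

    idata = pm.sample(1_000_000, ...)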