Complicated PyMC Mixture model

Assume I have a model $\mathcal M$ that can predict the observables $\mathcal O_1, \ldots, \mathcal O_M$ at the times $\tau_1, \ldots, \tau_N$. Similarly, I can measure the same observables at the same times, with uncertainties $\delta \mathcal O_1, \ldots, \delta\mathcal O_M$.

Now assume that I have several models $\mathcal M_1, \ldots, \mathcal M_K$ with a shared set of parameters $\theta$. I wish to compare these models using a linear mixture and learn the parameters $\theta$, but in such a way that the weight of each model can differ between the observation times $\tau_n$. Lastly, assume that the models are computationally expensive, so I use Gaussian processes (GPs) to emulate them.
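One way to write this time-dependent linear mixture down, assuming Gaussian measurement errors that are independent across observables and across observation times (assumptions not stated above). Here $\mathbf d_n$ is notation introduced only for this sketch: the measured vector of the $M$ observables at $\tau_n$. $\mathcal M_k(\tau_n; \theta)$ stands for model $k$'s (emulated) prediction of that vector, and $w_{n,k} \ge 0$ with $\sum_{k} w_{n,k} = 1$ for every $n$:

$$
p(\mathbf d_n \mid \theta, \mathbf w_n) = \sum_{k=1}^{K} w_{n,k}\, \mathcal N\!\big(\mathbf d_n \,\big|\, \mathcal M_k(\tau_n; \theta),\, \Sigma_n\big),
\qquad
\Sigma_n = \operatorname{diag}\!\big(\delta\mathcal O_1(\tau_n)^2, \ldots, \delta\mathcal O_M(\tau_n)^2\big),
$$

$$
\log p(\mathbf d_1, \ldots, \mathbf d_N \mid \theta, \mathbf w) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \exp\!\Big(\log w_{n,k} + \log \mathcal N\!\big(\mathbf d_n \,\big|\, \mathcal M_k(\tau_n; \theta),\, \Sigma_n\big)\Big).
$$

In the log form, each weight enters as $\log w_{n,k}$ added to a log-likelihood, rather than being multiplied with a distribution object.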

There are two ways to do this:

  1. Use the built-in `pm.Mixture` distribution
  2. Write my own mixture likelihood

I have questions regarding both, but before I drown you in the example code, is there any advice you can give based on the problem description above?

Proceed at your own risk
  1. If I use the built-in `pm.Mixture`, I can only mix the models at one observation time at a time. The code might look something like this (including the necessary GP setup). The problem here is that passing the data separately to every mixture makes constructing the PyTensor graph take ungodly long for sufficiently large $N$. Also, an emulator has to be constructed for each observable at each observation time. This could certainly be sped up with a multi-output GP that uses the observation times as additional inputs, but I am not confident enough in my PyMC fluency.
```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

data = ...  # np.ndarray with shape (N, M)
error = ...  # np.ndarray with shape (N, M), assumed to hold standard deviations
parameter_ranges = ...  # np.ndarray with shape (R, 2)
R = parameter_ranges.shape[0]
observable_names = ['O_1', ..., 'O_M']
# assume `model_names`, `observation_times`, `emulator_design_points` (shape (D, R))
# and `emulator_training_data` exist; the training data is different from `data`

with pm.Model() as gp_emulators:
    # setup for gp emulator: one length scale per parameter, shape (R,)
    cov_func = pm.gp.cov.Matern32(
        input_dim=R,
        ls=np.diff(parameter_ranges, axis=1).ravel(),
    )

    # emulators[name][i][j]: GP for model `name`, observation time i, observable j
    emulators = {name: [] for name in model_names}
    for name in model_names:
        for i, training_data in enumerate(emulator_training_data[name]):
            observable_emulators = []
            for j, observable in enumerate(observable_names):
                gp = pm.gp.Marginal(cov_func=cov_func)
                gp.marginal_likelihood(
                    name=f'{name}_{observable}_{i}',
                    X=emulator_design_points,
                    y=training_data[:, j],
                    sigma=0,
                )
                observable_emulators.append(gp)
            emulators[name].append(observable_emulators)

    inference_vars = pm.Uniform(
        'inference_vars',
        lower=parameter_ranges[:, 0],
        upper=parameter_ranges[:, 1],
        shape=R,
    )
    # GP inputs must be 2-D: a single point (row) with R columns
    Xnew = inference_vars[None, :]

    # predict the observables for the inference parameters, one mixture per time
    for i, tau in enumerate(observation_times):
        comp_dists = [
            pm.MvNormal.dist(
                # each conditional has shape (1,); concatenate them into shape (M,)
                mu=pt.concatenate([
                    emulators[name][i][j].conditional(
                        name=f'normal_{name}_{observable}_{tau}',
                        Xnew=Xnew,
                    )
                    for j, observable in enumerate(observable_names)
                ]),
                # covariance holds variances, hence the square of the uncertainties
                cov=np.diag(error[i] ** 2),
            )
            for name in model_names
        ]

        # Construct the mixture model with its own weights for this time
        alpha = pm.LogNormal(
            f'alpha_{i}',
            mu=0.0,
            sigma=1.0,
            shape=len(model_names),
        )
        weights = pm.Dirichlet(f'Dirichlet_{i}', a=alpha)
        pm.Mixture(
            f'mix_{i}',
            w=weights,
            comp_dists=comp_dists,
            # Note that I am passing the data piecewise, once per observation time
            observed=data[i],
        )
```
  2. How do I multiply a PyTensor random variable (a weight) by a distribution (the likelihood of a model)?