GPs & minibatch: shape mismatch?

I’m trying to model a linear trend in the presence of spatially correlated noise whose covariance is (more or less) known. I have a lot of points, so I’d like to model the trend as a line, model the spatial covariance with a 2D GP, and use minibatches to split the data into more manageable chunks.

Following the example of this tutorial, I create minibatches of the spatial coordinates, the independent variable of the linear model, and the observations:

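```python
import pymc3 as pm

# A comment cannot follow a backslash continuation, so create the
# minibatches one per line instead
indep_m = pm.Minibatch(indep)          # 'x' for the linear model
dep_m = pm.Minibatch(dep)              # observations
spatial_X_m = pm.Minibatch(spatial_X)  # spatial coordinates
```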
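For the minibatch model to make sense, the three minibatches have to select the same rows on every draw, so that each x value stays paired with its observation and its coordinates. As I understand it, separate `pm.Minibatch` objects stay aligned because they share the same default `random_seed`. The NumPy sketch below makes the pairing explicit instead, by drawing one index set and applying it to every array; `aligned_minibatch` and the synthetic data are hypothetical stand-ins, not PyMC3 internals.

```python
import numpy as np

def aligned_minibatch(rng, batch_size, *arrays):
    """Hypothetical helper: draw one set of row indices and apply it to
    every array, so x, y, and coordinates stay paired row-for-row."""
    n = len(arrays[0])
    # every array must describe the same set of points
    assert all(len(a) == n for a in arrays)
    idx = rng.choice(n, size=batch_size, replace=False)
    return tuple(a[idx] for a in arrays)

# hypothetical synthetic stand-ins for indep, dep, spatial_X
rng = np.random.default_rng(0)
n_points = 1784
indep = rng.uniform(0.0, 10.0, size=n_points)
spatial_X = rng.uniform(-5.0, 5.0, size=(n_points, 2))
dep = 8.5 - 0.05 * indep + rng.normal(0.0, 0.1, size=n_points)

indep_b, dep_b, spatial_b = aligned_minibatch(rng, 128, indep, dep, spatial_X)
print(indep_b.shape, dep_b.shape, spatial_b.shape)  # (128,) (128,) (128, 2)
```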

I then build a model on the minibatches and fit it with ADVI:

```python
with pm.Model() as metgrad_linearmodel_mini:
    # Define priors on linear regression
    sigma = pm.HalfCauchy('logOH12-sigma', beta=.5, testval=.1)
    intercept = pm.Normal('logOH12-x0', 8.5, sigma=20)
    x_coeff = pm.Normal('logOH12-slope', 0, sigma=20)

    # Define priors on spatial covariance (length scale held fixed)
    psf_ls = 1.5
    psf_eta = pm.HalfCauchy('PSF-eta', beta=.1, testval=.01)

    # Define the GP itself: one latent value per minibatch point
    psf_cov = psf_eta**2. * pm.gp.cov.ExpQuad(input_dim=2, ls=psf_ls)
    psf_gp = pm.gp.Latent(cov_func=psf_cov)
    noise_from_psf = psf_gp.prior('PSF', spatial_X_m, shape=dep_m.shape.eval())

    met_at_indep = pm.Deterministic(
        'logOH12-at-indep',
        intercept + (x_coeff * indep_m))

    # Define likelihood; total_size rescales the minibatch log-likelihood
    # to the size of the full data set
    likelihood = pm.Normal(
        'met', mu=met_at_indep + noise_from_psf, sigma=sigma,
        observed=dep_m, total_size=dep_o.shape)

    # Fit with ADVI (pm.fit's default method)
    approx = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
```
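The `total_size` argument is what lets a 128-point batch stand in for the whole data set: PyMC3 multiplies the batch log-likelihood by N/B (full size over batch size), so its expected value equals the full-data log-likelihood. The hypothetical NumPy sketch below checks this on synthetic data. It uses a random partition into equal batches, for which the average of the rescaled batch terms recovers the full sum exactly; `normal_logpdf` and `scaled_batch_loglik` are illustrative helpers, not PyMC3 functions.

```python
import numpy as np

def normal_logpdf(y, mu, sigma):
    # elementwise log density of N(mu, sigma^2)
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2

def scaled_batch_loglik(y_batch, mu_batch, sigma, total_size):
    # rescale a minibatch sum by N / B, the role total_size plays
    return total_size / len(y_batch) * normal_logpdf(y_batch, mu_batch, sigma).sum()

rng = np.random.default_rng(1)
N, B = 1024, 128  # synthetic sizes chosen so N / B is an integer
y = rng.normal(8.5, 0.1, size=N)
mu = np.full(N, 8.5)
sigma = 0.1

full = normal_logpdf(y, mu, sigma).sum()

# split a random permutation into N / B disjoint batches of size B
batches = rng.permutation(N).reshape(N // B, B)
estimates = [scaled_batch_loglik(y[b], mu[b], sigma, N) for b in batches]
print(np.isclose(np.mean(estimates), full))  # True
```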

This is where things get hairy…

I thought I could define the same model on the full data, referring to indep, dep, and spatial_X instead of indep_m, dep_m, and spatial_X_m, and then use the ADVI result to initialize NUTS:

```python
with pm.Model() as metgrad_linearmodel:
    # Define priors on linear regression
    sigma = pm.HalfCauchy('logOH12-sigma', beta=.5, testval=.1)
    intercept = pm.Normal('logOH12-x0', 8.5, sigma=20)
    x_coeff = pm.Normal('logOH12-slope', 0, sigma=20)

    # Define priors on spatial covariance (length scale held fixed)
    psf_ls = 1.5
    psf_eta = pm.HalfCauchy('PSF-eta', beta=.1, testval=.01)

    # Define the GP itself: now one latent value per data point
    psf_cov = psf_eta**2. * pm.gp.cov.ExpQuad(input_dim=2, ls=psf_ls)
    psf_gp = pm.gp.Latent(cov_func=psf_cov)
    noise_from_psf = psf_gp.prior('PSF', spatial_X)

    met_at_indep = pm.Deterministic(
        'logOH12-at-indep',
        intercept + (x_coeff * indep))

    # Define likelihood
    likelihood = pm.Normal(
        'met', mu=met_at_indep + noise_from_psf, sigma=sigma, observed=dep)

    # Inference! Reuse the ADVI fit to scale and start NUTS
    # (approx was fitted on the 128-point minibatch model)
    step = pm.NUTS(scaling=approx.cov.eval(), is_cov=True)
    start_full = approx.sample()[0]
    trace = pm.sample(500, tune=500, cores=1, step=step, start=start_full)
```

which throws the error:

```
ValueError: Bad shape for start argument:
Expected shape (1784,) for var 'PSF_rotated_', got: (128,)
```
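The two lengths in the message come from how the latent GP is parameterized. As far as I can tell, `pm.gp.Latent.prior` uses a non-centred form: it creates a standard-normal vector named `PSF_rotated_` (for a GP named 'PSF') with one entry per input point, and returns `L @ v`, where `L` is the Cholesky factor of the covariance matrix. The NumPy sketch below, with hypothetical helpers `exp_quad` and `latent_gp_draw` and the same `eta**2 * ExpQuad` kernel, shows that this vector has 128 entries for a minibatch but 1784 for the full data.

```python
import numpy as np

def exp_quad(X, eta, ls):
    # eta^2 * exp(-|x - x'|^2 / (2 ls^2)), i.e. psf_eta**2 * ExpQuad(input_dim=2, ls=psf_ls)
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return eta**2 * np.exp(-0.5 * sq / ls**2)

def latent_gp_draw(X, eta, ls, rng, jitter=1e-6):
    # non-centred ("whitened") draw: v ~ N(0, I), f = L v with L L^T = K
    K = exp_quad(X, eta, ls) + jitter * np.eye(len(X))
    L = np.linalg.cholesky(K)
    v = rng.standard_normal(len(X))  # one entry per input point, like PSF_rotated_
    return v, L @ v

rng = np.random.default_rng(2)
shapes = {}
for n in (128, 1784):  # minibatch size vs. full data size
    X = rng.uniform(-5.0, 5.0, size=(n, 2))  # synthetic 2D coordinates
    v, f = latent_gp_draw(X, eta=0.01, ls=1.5, rng=rng)
    shapes[n] = (v.shape, f.shape)
print(shapes)
```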

The error itself makes sense: the latent GP describing the spatially correlated noise is now evaluated at 1784 points instead of 128. For the same reason, approx.cov also has the wrong shape. So even if I initialize using only the start_full entries for logOH12-slope, logOH12-x0, and logOH12-sigma, I still have to throw away the covariance matrix that the minibatch fit worked so hard to find, which seems to defeat the purpose.
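The partial initialization described above amounts to keeping only the start_full entries whose shapes still match the full model and filling in the rest with defaults. A minimal sketch, assuming start values are plain NumPy arrays keyed by variable name: `merge_start` is hypothetical, the numbers are made up, and the `_log__` suffixes only mimic how PyMC3 names log-transformed variables. In PyMC3 the defaults could presumably come from the full model's test point. This recovers a starting point only; it does nothing for the NUTS scaling matrix.

```python
import numpy as np

def merge_start(start, defaults):
    """Hypothetical helper: keep entries of `start` whose shape matches
    the corresponding entry in `defaults`; use `defaults` otherwise."""
    merged = {}
    for name, default in defaults.items():
        value = start.get(name)
        if value is not None and np.shape(value) == np.shape(default):
            merged[name] = np.asarray(value)
        else:
            merged[name] = np.asarray(default)
    return merged

# hypothetical values; names mimic PyMC3's transformed-variable naming
start_full = {
    "logOH12-x0": np.array(8.4),
    "logOH12-slope": np.array(-0.04),
    "logOH12-sigma_log__": np.array(-2.3),
    "PSF-eta_log__": np.array(-4.6),
    "PSF_rotated_": np.zeros(128),   # sized for the 128-point minibatch
}
defaults = {
    "logOH12-x0": np.array(8.5),
    "logOH12-slope": np.array(0.0),
    "logOH12-sigma_log__": np.log(np.array(0.1)),
    "PSF-eta_log__": np.log(np.array(0.01)),
    "PSF_rotated_": np.zeros(1784),  # sized for the full data set
}
start = merge_start(start_full, defaults)
print({k: v.shape for k, v in start.items()})
```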

I’m wondering (a) whether minibatching is even valid for a GP model, and if so, (b) whether there’s a workaround that keeps the covariance estimated by the minibatch fit and lets me scale NUTS with it.

Any updates on this?