Correct way to use linear regression as mean of a GP

I want to fit the following model:

\begin{align}
y &\sim \text{MvN}(\mu, \Sigma) \\
\mu &= X\beta \\
\Sigma &= f(Z, \ell, \sigma) \\
\ell, \beta, \sigma &\sim \text{priors}
\end{align}

where f is some covariance function and X and Z are feature matrices. I guess I can do this with a GP, but I'm not clear on the "canonical" way to insert the linear regression part, \mu = X\beta, into the model. I have found three different ways of doing it scattered across the docs:

1. In the salmon spawning example, the empirical slope of the data is used as a constant without estimating anything.
2. In the spatial point-patterns example, a random variable is passed to pm.gp.mean.Constant, but no data.
3. In the Statistical Rethinking port, a subclass of gp.Mean is created to hold random variables, but the data is accessed via global variables (sketched just below).
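
For concreteness, the pattern in (3) looks roughly like this; the class and variable names here are mine, not the port's, and `X_features` stands for a module-level array:

```python
import pymc as pm

class RegressionMean(pm.gp.mean.Mean):
    """Linear-regression mean whose parameters are random variables."""

    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta

    def __call__(self, X):
        # The port reads the feature matrix from the enclosing scope
        # instead of using the X argument, which is the part that
        # feels off to me.
        return self.alpha + X_features @ self.beta
```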

Combining all three, my guess at the cleanest way would be to build the regression mean as a Deterministic and pass it to pm.gp.mean.Constant, as in:

# inside a pm.Model with coords for 'index' and 'features'
X = pm.MutableData('X', df[features], dims=['index', 'features'])
alpha = pm.Normal('alpha')
beta = pm.Normal('beta', dims=['features'])
# X @ beta sums over 'features', so the result is indexed by 'index' only
regression_mean = pm.Deterministic('regression_mean', alpha + X @ beta, dims=['index'])
mean_func = pm.gp.mean.Constant(regression_mean)

# rest of gp stuff as usual

Is this kosher? (1) makes me think not, especially in the broader context of that lecture (it sets up a linear regression example and then quietly abandons it in favor of the heuristic slope), while (2) makes me think it should work fine. But if it is fine, why is the subclassing needed in (3)?

CC @bwengals

The purpose of the mean function is to make both prediction and gp.Marginal work; it basically acts as a very simple passthrough for the non-GP parts of the model. For your case it'd be easiest either to use the built-in Linear mean function or to write your own subclass in the same way as (3). The idea is that the predictors you'd change with pm.set_data, so X, are always the only input to __call__, and anything else you need, like alpha and beta, you stash on the mean function in __init__. pm.gp.mean.Linear may be exactly what you need, or a subclass that's very similar.
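
Very roughly, the Linear route could look like the sketch below. The priors, the ExpQuad covariance, and names like `ell` and `eta` are placeholders, and it assumes the mean and covariance both act on the columns of X. Note that recent PyMC versions call the noise argument of marginal_likelihood `sigma` (older ones use `noise`):

```python
import pymc as pm

# A sketch, not a drop-in solution: df, features, and the priors
# below are placeholders.
coords = {"index": df.index, "features": features}
with pm.Model(coords=coords) as model:
    X = pm.MutableData("X", df[features], dims=["index", "features"])

    # Parameters of the linear mean; Linear stashes them internally,
    # so X is the only input __call__ ever needs.
    alpha = pm.Normal("alpha")
    beta = pm.Normal("beta", dims=["features"])
    mean_func = pm.gp.mean.Linear(coeffs=beta, intercept=alpha)  # mu(X) = X @ beta + alpha

    # Covariance hyperparameters (placeholder choices)
    ell = pm.Gamma("ell", alpha=2, beta=1)
    eta = pm.HalfNormal("eta", sigma=1)
    cov_func = eta**2 * pm.gp.cov.ExpQuad(input_dim=len(features), ls=ell)

    gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
    noise = pm.HalfNormal("noise", sigma=1)
    y_ = gp.marginal_likelihood("y_", X=X, y=df["y"].values, sigma=noise)
```

If the mean really does use different features than the covariance (your X vs Z), the subclass version is the same idea: pass the combined input to __call__, slice out the mean's columns there, and stash alpha and beta in __init__, exactly as in (3) but without the globals.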