Oh ok, I think I get it now. How about this? You can combine covariance functions (in general, not talking about pymc3 specifically) like this:
f(x)k(x, x')f(x')
where f(x) is some scaling function of x or maybe other parameters.
But in your case, f(x) is just x4, right? Here is a class you can use to do this. I haven’t tested it carefully, so keep an eye on it. This really should be added to the codebase…
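import pymc3 as pm
import theano.tensor as tt


class ScalingFunc(pm.gp.cov.Covariance):
    def __init__(self, input_dim, sc_func, args=None, active_dims=None):
        super(ScalingFunc, self).__init__(input_dim, active_dims)
        # handle_args wraps sc_func so it can be called as sc_func(X, args)
        self.sc_func = pm.gp.cov.handle_args(sc_func, args)
        self.args = args

    def full(self, X, Xs=None):
        X, Xs = self._slice(X, Xs)
        sc_x = self.sc_func(tt.as_tensor_variable(X), self.args)
        if Xs is None:
            return tt.outer(sc_x, sc_x)
        else:
            sc_xs = self.sc_func(tt.as_tensor_variable(Xs), self.args)
            return tt.outer(sc_x, sc_xs)

    def diag(self, X):
        # slice to active_dims here too, so diag() is consistent with full()
        X, _ = self._slice(X, None)
        sc_x = self.sc_func(tt.as_tensor_variable(X), self.args)
        # flatten so the diagonal is a vector, like the other kernels' diag()
        return tt.square(tt.flatten(sc_x))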
And a simple example:
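import numpy as np
import matplotlib.pyplot as plt

X = np.linspace(0, 10, 100)[:, None]
sc_func = lambda X: X  # f(x) = x
k3 = pm.gp.cov.ExpQuad(1, 1)
k4 = ScalingFunc(1, sc_func)
k = k3 * k4  # f(x) k(x, x') f(x')

# plot the resulting covariance matrix
m = plt.imshow(k(X).eval())
plt.colorbar(m)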
The scaling function, sc_func, is used the same way as the function inputs for
WarpedInput and Gibbs, so to generalize check out the docs on those.
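
If you want to sanity check what the example above should produce, here is a rough NumPy-only sketch of the same f(x)k(x, x')f(x') matrix. The helpers `exp_quad` and `scaled_cov` are names I made up for this, not part of pymc3, and I'm assuming the usual ExpQuad form exp(-0.5 (x - x')^2 / ls^2) with ls = 1.

```python
import numpy as np


def exp_quad(X, Xs, ls=1.0):
    # squared-exponential kernel for 1-D inputs stored as (n, 1) columns
    sqdist = (X[:, None, 0] - Xs[None, :, 0]) ** 2
    return np.exp(-0.5 * sqdist / ls**2)


def scaled_cov(X, Xs, f, ls=1.0):
    # f(x) k(x, x') f(x'): the outer product of the scalings times the base kernel
    fx = np.ravel(f(X))
    fxs = np.ravel(f(Xs))
    return np.outer(fx, fxs) * exp_quad(X, Xs, ls)


X = np.linspace(0, 10, 100)[:, None]
K = scaled_cov(X, X, lambda x: x)  # f(x) = x, as in the example above
```

The diagonal of `K` is just f(x)^2, and the matrix stays symmetric, so comparing it against `k(X).eval()` is a quick way to catch a mistake in the class.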