Gaussian Process predict over trace

Hi everyone,

I’d like to integrate a function, a, of the predicted mean (μ(x)) and standard deviation (σ(x)) over the posterior of the inferred model parameters in a Gaussian Process model in PyMC3.

If θ denotes the collected hyperparameters of the GP model (e.g. the noise, length scale and amplitude of the covariance function), then I want:

$$\bar a(x) = \int a\big(\mu(x;\theta),\,\sigma(x;\theta)\big)\,p(\theta \mid y)\,d\theta$$
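Because the trace holds draws θ_s from the posterior p(θ | y), the integral can be approximated by a plain Monte Carlo average over the S draws. This is what the loop over the trace further down computes:

$$\bar a(x) = \int a\big(\mu(x;\theta),\,\sigma(x;\theta)\big)\,p(\theta \mid y)\,d\theta \;\approx\; \frac{1}{S}\sum_{s=1}^{S} a\big(\mu(x;\theta_s),\,\sigma(x;\theta_s)\big), \qquad \theta_s \sim p(\theta \mid y)$$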

My model is specified and conditioned like this:

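import numpy as np
import pymc3 as pm

# x: training inputs with shape (n, 1); y: observations with shape (n,)
with pm.Model() as model:
    l = pm.Gamma("l", alpha=2, beta=1)        # length scale
    eta = pm.HalfCauchy("eta", beta=5)        # amplitude

    cov = eta**2 * pm.gp.cov.Matern52(1, l)
    gp = pm.gp.Marginal(cov_func=cov)

    sigma = pm.HalfCauchy("sigma", beta=5)    # observation noise sd
    y_ = gp.marginal_likelihood("y", X=x, y=y, noise=sigma)

    trace = pm.sample(1000)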

I can predict the mean and variance at the new inputs for a single point in the trace:

mu, var = gp.predict(Xnew, point=trace[42], diag=True)  # diag=True returns the variance vector, not the full covariance

and then calculate my function (σ is the square root of the returned variance):

a = my_func(mu, np.sqrt(var))

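To average over all draws in the trace, I’m doing this:

draws = list(trace.points())           # every posterior draw from every chain (trace[i] only indexes the last chain)
a = np.empty((len(Xnew), len(draws)))  # one column per draw
for i, point in enumerate(draws):
    mu, var = gp.predict(Xnew, point=point, diag=True)
    a[:, i] = my_func(mu, np.sqrt(var))
integrated_a = a.mean(axis=1)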
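One general strategy is to do the per-draw prediction in plain NumPy instead of calling gp.predict, which may build and evaluate a Theano graph on every call. The sketch below assumes the model above exactly: zero mean function, a 1-D input, covariance η²·Matern52(ℓ) plus Gaussian noise σ, and prediction of the latent f without noise (like gp.predict with the default pred_noise=False). The helper names matern52, predict_diag and integrate_a are hypothetical, and my_func stands for your own function.

```python
import numpy as np

def matern52(XA, XB, ls, eta):
    """eta**2 * Matern 5/2 covariance between inputs of shape (n, 1) and (m, 1)."""
    r = np.abs(XA[:, 0][:, None] - XB[:, 0][None, :]) / ls
    s5r = np.sqrt(5.0) * r
    return eta**2 * (1.0 + s5r + 5.0 / 3.0 * r**2) * np.exp(-s5r)

def predict_diag(X, y, Xnew, ls, eta, sigma):
    """Mean and variance of the latent f at Xnew for one draw (zero mean, no predictive noise)."""
    K = matern52(X, X, ls, eta) + sigma**2 * np.eye(len(X))
    Ks = matern52(X, Xnew, ls, eta)          # (n, m) cross-covariance K_*
    L = np.linalg.cholesky(K)
    v = np.linalg.solve(L, Ks)               # L^-1 K_*
    mu = v.T @ np.linalg.solve(L, y)         # K_*^T K^-1 y
    var = eta**2 - np.sum(v**2, axis=0)      # diag(K_**) = eta**2 for a stationary kernel
    return mu, var

def integrate_a(X, y, Xnew, ls_draws, eta_draws, sigma_draws, my_func):
    """Monte Carlo estimate of E_theta[my_func(mu(x; theta), sd(x; theta))]."""
    a = np.empty((len(Xnew), len(ls_draws)))
    for i, (ls, eta, sigma) in enumerate(zip(ls_draws, eta_draws, sigma_draws)):
        mu, var = predict_diag(X, y, Xnew, ls, eta, sigma)
        # clip tiny negative variances caused by round-off before the sqrt
        a[:, i] = my_func(mu, np.sqrt(np.clip(var, 0.0, None)))
    return a.mean(axis=1)
```

With a PyMC3 MultiTrace, the draws can be passed as arrays, e.g. integrate_a(x, y, Xnew, trace["l"], trace["eta"], trace["sigma"], my_func); indexing a trace by variable name returns the values from all chains combined.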

But it’s very, very slow. Is there a neat / faster way of doing this? A general strategy would be really helpful. I’m not an expert in Theano but willing to learn.

Thanks in advance.

To do exactly this: no, it’s not possible, primarily because Theano doesn’t seem to support a Cholesky decomposition over n-dimensional arrays, so the covariance matrices for different draws can’t be stacked. To integrate out θ, could you use sample_ppc instead?
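Outside Theano, stacking is possible: np.linalg.cholesky and np.linalg.solve both operate on stacks of matrices along the leading axes. A minimal sketch, assuming the same Matern 5/2 plus noise model with zero mean and 1-D inputs, and the hyperparameter draws given as 1-D arrays (for example trace["l"], trace["eta"], trace["sigma"]). The function names are hypothetical, and a general solver is used on the triangular factor for simplicity.

```python
import numpy as np

def matern52_batched(XA, XB, ls, eta):
    """Stack of eta**2 * Matern 5/2 covariances, one per draw: shape (S, len(XA), len(XB))."""
    ls = np.asarray(ls, dtype=float)
    eta = np.asarray(eta, dtype=float)
    r = np.abs(XA[:, 0][None, :, None] - XB[:, 0][None, None, :]) / ls[:, None, None]
    s5r = np.sqrt(5.0) * r
    return (eta**2)[:, None, None] * (1.0 + s5r + 5.0 / 3.0 * r**2) * np.exp(-s5r)

def gp_predict_diag_batched(X, y, Xnew, ls, eta, sigma):
    """Latent-f predictive mean and variance at Xnew for all S draws at once, each (S, m)."""
    ls, eta, sigma = (np.asarray(v, dtype=float) for v in (ls, eta, sigma))
    S, n = len(ls), len(X)
    K = matern52_batched(X, X, ls, eta) + (sigma**2)[:, None, None] * np.eye(n)
    Ks = matern52_batched(X, Xnew, ls, eta)                 # (S, n, m)
    L = np.linalg.cholesky(K)                               # batched over the leading axis
    Ly = np.linalg.solve(L, np.broadcast_to(y[:, None], (S, n, 1)))  # (S, n, 1) = L^-1 y
    V = np.linalg.solve(L, Ks)                              # (S, n, m) = L^-1 K_*
    mu = np.einsum("snm,sn->sm", V, Ly[..., 0])             # K_*^T K^-1 y = V^T (L^-1 y)
    var = (eta**2)[:, None] - np.sum(V**2, axis=1)          # eta**2 - diag(K_*^T K^-1 K_*)
    return mu, var
```

The integrated function is then a single reduction, e.g. my_func(mu, np.sqrt(np.clip(var, 0, None))).mean(axis=0) with my_func broadcasting over the draw axis. The stacked covariance needs S·n² floats, so for many draws or many training points it is worth processing the draws in chunks.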

Thanks for your reply. I thought of using sample_ppc, something like:

with model:
    f = gp.conditional("f", Xnew)
    # pseudocode: mean(f) and diag_cov(f) stand for the quantities I want, they are not real functions
    pred_samples = pm.sample_ppc(trace, vars=[mean(f), diag_cov(f)], samples=2000)

where mean(f) and diag_cov(f) would be the mean vector and diagonal of the covariance matrix of the random variable f. I just don’t know how to properly specify mean(f) and diag_cov(f).
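If per-draw predictive moments μ_s and σ²_s are available (for instance from a NumPy loop over the draws), the mean vector and the diagonal of the covariance of f with θ integrated out follow from the law of total variance: E[f] = E_θ[μ] and Var[f] = E_θ[σ²] + Var_θ[μ]. The mean and variance of draws of f taken from sample_ppc with vars=[f] should estimate the same two quantities. Note that these are moments of the θ-marginalised f, which in general are not the same as averaging a(μ_s, σ_s) over the draws. The helper name total_moments is hypothetical.

```python
import numpy as np

def total_moments(mu_draws, var_draws):
    """Combine per-draw predictive moments, shape (S, m), into moments of f with theta integrated out."""
    mu_draws = np.asarray(mu_draws, dtype=float)
    var_draws = np.asarray(var_draws, dtype=float)
    mean = mu_draws.mean(axis=0)                           # E_theta[mu]
    var = var_draws.mean(axis=0) + mu_draws.var(axis=0)    # E_theta[sigma^2] + Var_theta[mu]
    return mean, var
```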

EDIT: OK, I guess what you’re saying applies here too, so you can’t put something like diag_cov(f) into sample_ppc. Damn.