Gaussian Process predict over trace


#1

Hi everyone,

I’d like to integrate a function, a, of the predicted mean (μ(x)) and standard deviation (σ(x)) over the inferred model parameters in a Gaussian Process model in PyMC3.

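If θ denotes the collected parameters of the GP model (e.g. the noise, the length scales, etc. of the covariance function), then I want the posterior average:

$$\bar{a}\big(\mu(x), \sigma(x)\big) = \int d\theta \; a\big(\mu(x), \sigma(x); \theta\big)\, p(\theta \mid y)$$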
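Averaging over the trace, as done further down, is the usual Monte Carlo estimate of this integral: each of the S posterior draws θ^(s) in the trace contributes one evaluation of a, with equal weight (this assumes the draws are approximately independent samples from the posterior).

$$\bar{a}\big(\mu(x), \sigma(x)\big) \approx \frac{1}{S} \sum_{s=1}^{S} a\big(\mu(x), \sigma(x); \theta^{(s)}\big), \qquad \theta^{(s)} \sim p(\theta \mid y)$$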

My model is specified and conditioned like this:

import numpy as np
import pymc3 as pm

# x: (n, 1) array of training inputs, y: (n,) array of observations
with pm.Model() as model:
    l = pm.Gamma("l", alpha=2, beta=1)
    eta = pm.HalfCauchy("eta", beta=5)

    cov = eta**2 * pm.gp.cov.Matern52(1, l)
    gp = pm.gp.Marginal(cov_func=cov)

    sigma = pm.HalfCauchy("sigma", beta=5)
    y_ = gp.marginal_likelihood("y", X=x, y=y, noise=sigma)

    trace = pm.sample(1000)

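I can predict the mean and variance at a single point of the trace (with diag=True, gp.predict returns the predictive variance σ(x)² instead of the full covariance matrix):

mu, var = gp.predict(Xnew, point=trace[42], diag=True)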

and then calculate my function:

a = my_func(mu, var)

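To average over all draws in the trace I’m doing this:

a = []
for point in trace.points():  # every draw from every chain
    mu, var = gp.predict(Xnew, point=point, diag=True)
    a.append(my_func(mu, var))
integrated_a = np.mean(a, axis=0)  # Monte Carlo average over the posterior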

But it’s very, very slow. Is there a neat / faster way of doing this? A general strategy would be really helpful. I’m not an expert in Theano but willing to learn.

Thanks in advance.


#2

To do exactly this: no, it’s not possible, mainly because Theano doesn’t seem to support a Cholesky decomposition over n-dimensional (batched) arrays, which is what stacking the covariance matrices for all draws would require. To integrate out θ, could you use sample_ppc instead?
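Outside Theano, NumPy's `np.linalg.cholesky` and `np.linalg.solve` do broadcast over leading batch dimensions, so the stacked-covariance idea can at least be sketched in plain NumPy. The sketch below is an assumption-laden re-implementation, not PyMC3 API: `matern52` and `batched_gp_predict` are hypothetical helpers that assume the model from the first post (zero mean function, `eta**2 * Matern52` covariance, Gaussian noise `sigma`, 1-D inputs) and compute the predictive mean and variance of the latent f for all S posterior draws at once. It should agree with `gp.predict(..., diag=True)` up to the small numerical jitter PyMC3 may add, but that is worth checking on your own model. Memory grows as S·n², so with many draws or many training points you would process the draws in chunks.

```python
import numpy as np


def matern52(x1, x2, ls, eta):
    """eta**2 * Matern 5/2 kernel, evaluated for S posterior draws at once.

    x1: (n,), x2: (m,) 1-D inputs; ls, eta: (S,) draws. Returns (S, n, m).
    """
    dist = np.abs(x1[:, None] - x2[None, :])            # (n, m)
    r = dist[None, :, :] / ls[:, None, None]             # (S, n, m)
    s5r = np.sqrt(5.0) * r
    return (eta**2)[:, None, None] * (1.0 + s5r + 5.0 / 3.0 * r**2) * np.exp(-s5r)


def batched_gp_predict(x, y, xnew, ls, eta, sigma):
    """Predictive mean and variance of the latent f at xnew for every draw.

    Assumes a zero mean function and Gaussian noise with std `sigma`,
    as in the model from the first post. Returns two (S, m) arrays.
    """
    S, n = ls.shape[0], x.shape[0]
    # Noisy training covariance: one (n, n) matrix per draw
    K = matern52(x, x, ls, eta) + (sigma**2)[:, None, None] * np.eye(n)
    Ks = matern52(x, xnew, ls, eta)                      # (S, n, m)
    L = np.linalg.cholesky(K)                            # batched over S
    v = np.linalg.solve(L, Ks)                           # L^{-1} Ks, (S, n, m)
    w = np.linalg.solve(L, np.broadcast_to(y[:, None], (S, n, 1)))  # L^{-1} y
    mu = np.sum(v * w, axis=1)                           # Ks^T K^{-1} y
    # Prior variance of f is eta**2 on the diagonal (Matern 5/2 has k(0) = 1)
    var = (eta**2)[:, None] - np.sum(v**2, axis=1)       # diag(Kss - Ks^T K^{-1} Ks)
    return mu, var


# With a PyMC3 MultiTrace (x and Xnew flattened to 1-D), something like:
#   mu, var = batched_gp_predict(x.ravel(), y, Xnew.ravel(),
#                                trace["l"], trace["eta"], trace["sigma"])
#   integrated_a = my_func(mu, var).mean(axis=0)  # my_func must be vectorized
```

The per-draw loop disappears into two batched triangular-style solves; the price is holding all S covariance matrices in memory at once.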


#3

Thanks for your reply. I thought of using sample_ppc, something like:

with model:
    f = gp.conditional('f', Xnew)
    # mean(f) and diag_cov(f) are placeholders for what I'd like to compute
    pred_samples = pm.sample_ppc(trace, vars=[mean(f), diag_cov(f)], samples=2000)

where mean(f) and diag_cov(f) would be the mean vector and diagonal of the covariance matrix of the random variable f. I just don’t know how to properly specify mean(f) and diag_cov(f).

EDIT: OK, I guess what you’re saying applies here too, so you can’t put something like diag_cov(f) into sample_ppc. Damn.
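A side note on the sample_ppc route: even when only f itself is sampled (e.g. `pm.sample_ppc(trace, vars=[f], samples=2000)`), the mean and variance of the pooled draws describe the θ-marginalized predictive distribution, not the posterior average of a(μ(x), σ(x); θ). By the law of total variance, the pooled variance equals E_θ[σ_θ(x)²] + Var_θ[μ_θ(x)], so for a nonlinear a the two quantities generally differ. The NumPy sketch below checks this on made-up per-draw moments at a single input; the numbers are synthetic stand-ins, not output from a GP.

```python
import numpy as np

rng = np.random.default_rng(1)
S, per_draw = 200, 500
# Synthetic per-draw predictive moments at one input x (one entry per theta draw)
mu_s = rng.normal(0.0, 1.0, size=S)
sd_s = rng.uniform(0.5, 1.5, size=S)

# Stand-in for pooled draws of f: `per_draw` samples of f for each theta draw
f = rng.normal(mu_s[:, None], sd_s[:, None], size=(S, per_draw))
pooled_var = f.var()

# Law of total variance (exact here: equal group sizes, ddof=0)
within = f.var(axis=1).mean()    # average within-draw variance
between = f.mean(axis=1).var()   # variance of the within-draw means

# Nonlinear example a(mu, sigma) = sigma: posterior average vs pooled estimate
avg_sd = sd_s.mean()
pooled_sd = np.sqrt(pooled_var)
```

Here `pooled_sd` comes out noticeably larger than `avg_sd`, because the pooled draws also absorb the spread of μ across θ.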