Hey, I’ve been enjoying PyMC a lot, so thanks to the devs and the community for keeping up the good work!
For an assignment at my uni I’ve been experimenting a lot with Gaussian processes (GPs), mainly because of their ability to approximate a wide variety of functions. My goal now is to fit multidimensional surfaces to a set of data points. I’ve come up with the following code:
```python
import numpy as np
import pymc3 as pm


def Fit_GP(X, f_true):
    with pm.Model() as model:
        # one length scale per input dimension, plus an amplitude
        ℓ = pm.HalfCauchy("ℓ", beta=1, shape=2)
        η = pm.HalfCauchy("η", beta=5)
        cov = η**2 * pm.gp.cov.Matern52(2, ℓ)
        gp = pm.gp.Marginal(cov_func=cov)
        # essentially noise-free observations of the surface
        y_ = gp.marginal_likelihood("y", X=X, y=f_true, noise=1e-10)
        trace = pm.sample(draws=100)
        pm.traceplot(trace)
        summ = pm.summary(trace)
        print(summ)

    # prediction grid spanning the range of the training inputs
    x_min = np.min(X, axis=0)
    x_max = np.max(X, axis=0)
    n = len(X)
    x_new = np.linspace(x_min, x_max, n)  # shape (n, 2)
    x1, x2 = np.meshgrid(x_new, x_new)    # meshgrid flattens x_new, so each axis gets 2n values
    X_new = np.concatenate((x1.reshape((-1, 1)), x2.reshape((-1, 1))), axis=1)  # (2n)**2 rows

    mu, var = gp.predict(X_new, diag=True)

    # add the GP conditional to the model, given the new X values
    with model:
        f_pred = gp.conditional("f_pred", X_new)
        pred_samples = pm.sample_posterior_predictive(trace, vars=[f_pred], samples=10)

    return mu, var, pred_samples["f_pred"], x1, x2
```
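To make the size of the prediction grid concrete, here is a NumPy-only check of the grid construction in the function, run on a hypothetical 6 × 6 set of training inputs on the unit square. For comparison it also builds a conventional per-axis grid with `m` points per axis; `m`, `X_grid` and the unit-square inputs are illustrative assumptions, not part of the original code.

```python
import numpy as np

# Hypothetical training inputs: a 6 x 6 grid on the unit square (36 points).
g = np.linspace(0.0, 1.0, 6)
X = np.array([(a, b) for a in g for b in g])

x_min = np.min(X, axis=0)  # shape (2,)
x_max = np.max(X, axis=0)  # shape (2,)
n = len(X)                 # 36: total number of points, not points per axis

# Grid as built in Fit_GP: linspace with vector endpoints returns shape (n, 2),
# and meshgrid flattens each input, so both axes end up with 2 * n values.
x_new = np.linspace(x_min, x_max, n)
x1, x2 = np.meshgrid(x_new, x_new)
X_new = np.concatenate((x1.reshape((-1, 1)), x2.reshape((-1, 1))), axis=1)
print(x_new.shape, X_new.shape)

# Per-axis alternative: one linspace per input dimension, m points each.
m = 6
axis_0 = np.linspace(x_min[0], x_max[0], m)
axis_1 = np.linspace(x_min[1], x_max[1], m)
y1, y2 = np.meshgrid(axis_0, axis_1)
X_grid = np.column_stack([y1.ravel(), y2.ravel()])
print(X_grid.shape)
```

With 36 training points the first construction yields 5184 prediction rows, while the per-axis grid yields 36.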
I randomly generate data from a zero-mean GP with a Matern52 kernel and then try to fit the GP hyperparameters using the “Inference Button”. This procedure works well for 1D GPs, but when scaling it to 2D I find that `sample_posterior_predictive` stalls and can’t produce results (even with a meager 10 samples). This apparently happens because arrays of around 60 GB show up during the operation, and Python just crashes.
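For a rough sense of scale: a joint draw from `gp.conditional` involves a dense covariance matrix over all prediction points, so memory grows with the square of the number of grid rows. The estimate below assumes float64 values and counts a single such matrix only; the helper names are hypothetical, and `posted_grid_rows` simply encodes the `(2n)**2` row count that the grid code in the function produces.

```python
def dense_cov_bytes(n_rows: int, bytes_per_value: int = 8) -> int:
    """Bytes needed for one dense n_rows x n_rows float64 matrix."""
    return n_rows * n_rows * bytes_per_value


def posted_grid_rows(n_train: int) -> int:
    """Rows of X_new built by the grid code in Fit_GP: linspace gives (n_train, 2),
    meshgrid flattens it to 2 * n_train values per axis, hence (2 * n_train) ** 2 rows."""
    return (2 * n_train) ** 2


for per_axis in (6, 10, 20):
    n_train = per_axis**2
    rows = posted_grid_rows(n_train)
    print(f"{per_axis} per axis: {n_train} training points, {rows} prediction rows, "
          f"{dense_cov_bytes(rows) / 1e9:.2f} GB per dense covariance")
```

With 6 points per axis the grid has 5184 rows (about 0.21 GB for one dense covariance); with 10 per axis it has 40000 rows (about 12.8 GB).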
So my question is: can someone point me to a more efficient way to do this? Perhaps the sparse approximation implementation would help, but the problem already crashes with 6 points per axis, so I’m kind of discouraged from using it. Then again, I’m not fluent in the details of the algorithm, so I can’t really say it won’t work. My intent is to extend this to more dimensions and then use it to model a real-valued function.
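One possible direction, sketched here as an assumption rather than a PyMC recipe: if only the posterior mean and pointwise variance over the surface are needed (which is what `gp.predict(..., diag=True)` returns), they can be computed without ever forming the covariance between all prediction points, by processing the grid in chunks. The sketch uses plain NumPy with fixed hyperparameters; `matern52`, `predict_mean_var`, the `chunk` size, the `1e-6` jitter and the test surface are all hypothetical choices.

```python
import numpy as np


def matern52(XA, XB, ls, eta):
    """ARD Matern 5/2 covariance eta**2 * (1 + sqrt(5) r + 5/3 r**2) * exp(-sqrt(5) r),
    where r is the Euclidean distance between inputs divided elementwise by ls."""
    A = np.asarray(XA, dtype=float) / ls
    B = np.asarray(XB, dtype=float) / ls
    sq = (A**2).sum(axis=1)[:, None] + (B**2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    r = np.sqrt(np.maximum(sq, 0.0))  # clip tiny negative values caused by rounding
    s5r = np.sqrt(5.0) * r
    return eta**2 * (1.0 + s5r + 5.0 / 3.0 * r**2) * np.exp(-s5r)


def predict_mean_var(X, y, X_new, ls, eta, noise=1e-6, chunk=1000):
    """Posterior mean and pointwise variance of a zero-mean GP at X_new.

    Only (n_train x chunk) cross-covariance blocks are built, never the
    full covariance between all rows of X_new.
    """
    K = matern52(X, X, ls, eta) + noise * np.eye(len(X))  # noise/jitter on the diagonal
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))   # (K + noise * I)^{-1} y
    mu = np.empty(len(X_new))
    var = np.empty(len(X_new))
    for start in range(0, len(X_new), chunk):
        stop = start + chunk
        Ks = matern52(X, X_new[start:stop], ls, eta)       # shape (n_train, <= chunk)
        mu[start:stop] = Ks.T @ alpha
        v = np.linalg.solve(L, Ks)                         # L^{-1} K_*
        var[start:stop] = eta**2 - (v**2).sum(axis=0)      # k(x*, x*) = eta**2 for Matern
    return mu, var


# Usage on a hypothetical 6 x 6 training grid and a 50 x 50 prediction grid.
g = np.linspace(0.0, 1.0, 6)
X = np.array([(a, b) for a in g for b in g])
y = np.sin(3.0 * X[:, 0]) * np.cos(2.0 * X[:, 1])  # hypothetical test surface
grid_1d = np.linspace(0.0, 1.0, 50)
u, w = np.meshgrid(grid_1d, grid_1d)
X_new = np.column_stack([u.ravel(), w.ravel()])
mu, var = predict_mean_var(X, y, X_new, ls=np.array([0.3, 0.3]), eta=1.0, chunk=500)
```

Memory here scales with `n_train * chunk` instead of the square of the grid size. Plugging in single values for `ℓ` and `η` (for example posterior means taken from the trace) ignores hyperparameter uncertainty, and this does not give joint samples of the surface, only pointwise summaries.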