How to improve speed and accuracy of sparse gaussian process multi-class classification?

I want to improve the speed and accuracy of sparse Gaussian process multi-class classification.
Currently I am using the Iris dataset for experimentation, but later I will use the model to fit a dataset with several thousand datapoints, so it needs to be sparse.

What kind of accuracy problem are you seeing?

Improving the speed is likely difficult: your model seems quite compact, so the bottleneck is probably Theano itself. A more informative prior could help, but I am not sure.

The train and test accuracy when using MarginalSparse, and when using Latent with reparameterize=False, is around 30% (in other words, random guessing). However, with reparameterize=True it is around 80-90%, so I know the model should work, but it doesn't with MarginalSparse. Looking at the MarginalSparse code, it seems equivalent to the Latent code with reparameterize=False, which is why I used that as a surrogate to diagnose the problem with the rest of my model.
I am suspicious about the optimization of the MvNormal when reparameterize=False, because it keeps pushing eta and the lengthscale toward 0. The logp reported by find_MAP keeps increasing: even if I initialize to the best eta and ls from the Latent result, where the logp is around -300, after optimization it becomes around 1200 (the same happens with ADVI, where the loss increases rather than decreasing).

So your implementation using Latent is correct. In Latent, reparameterize=True is nearly always the best thing to do, to the point where it may not be worth including the option not to reparameterize. In MarginalSparse, you're setting is_observed, not reparameterize.

There’s no LatentSparse implementation in PyMC3 yet, which is what you need. Read in the context of what you’re trying to do, the docstring for MarginalSparse is not at all clear about this. There is a PR in progress, which doesn’t use a variational formulation, just the DTC approximation. Also see this gist.


May I know the difference between these two?

```
Qss =, tmp)
# Qss =,, tt.nlinalg.pinv(Kuu)), Ksu.T)
```

Both of them seem the same when I tried them on the Iris dataset.
Is one of them supposed to be more accurate but slower, or vice versa?

Yes, they’re the same. The version with pinv is slower and less numerically stable.