How to increase sampling speed with pm.sample

Hi,
I'm trying to apply kriging (Gaussian process regression) to a data set of about 400 data points, but sampling takes about a day, which does not seem right.
python = 3.9.12
pymc3 = 3.11.5

I get the warning

WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions.

I think this is the cause, but I can't seem to resolve it.

I have tried the following:

    conda install numpy scipy mkl-service libpython m2w64-toolchain
    conda install theano pygpu
    conda install pymc3

and creating a .theanorc.txt file, but it hasn't worked.

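    import matplotlib.pyplot as plt
    import pymc3 as pm

    # X: (n, 3) array of input coordinates and y: (n,) array of observations,
    # loaded earlier (about 400 points here)

    with pm.Model() as model:
        c = pm.Normal(name='c', mu=0, sigma=10, shape=1)   # constant mean
        ls = pm.HalfNormal(name='ls', sigma=1, shape=3)    # one length scale per input dimension
        k = pm.Exponential('k', 0.1)                       # amplitude of the kernel

        cov_func = pm.gp.cov.ExpQuad(input_dim=3, ls=ls)
        mean_func = pm.gp.mean.Constant(c)
        noise = 5                                          # fixed observation noise (standard deviation)

        gp = pm.gp.Marginal(mean_func, k * cov_func)
        f = gp.marginal_likelihood("f", X, y, noise)

    with model:
        trace = pm.sample(1000, chains=2, return_inferencedata=False)

    pm.traceplot(trace)
    plt.show()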

Any help would be really appreciated, thanks!

You are correct that the BLAS warning is likely the cause of the poor performance. If you can, I would strongly suggest installing PyMC version 5 (see the PyMC installation instructions) before trying to get too far in fixing your current v3 environment.


Yeah, not having BLAS will definitely slow things down, though GPs tend to be slow anyway: every evaluation of the marginal likelihood (and its gradient) requires factorizing the 400×400 covariance matrix, which scales as O(n³). When this happens, I usually have more luck installing via pip instead of conda.
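To illustrate where the time goes, here is a minimal NumPy/SciPy sketch (not PyMC's actual implementation) of the log marginal likelihood that a `Marginal` GP with an ExpQuad kernel evaluates. The function names are hypothetical. The Cholesky factorization is the O(n³) step, and NUTS repeats it every time it evaluates the density at new hyperparameter values, i.e. at every leapfrog step. That is why a slow or missing BLAS hurts GP models so much.

```python
import numpy as np
from scipy.linalg import cho_solve, cholesky


def exp_quad(X1, X2, ls, amp):
    """Squared-exponential (ExpQuad) kernel, one length scale per input dimension."""
    d = (X1[:, None, :] - X2[None, :, :]) / ls
    return amp * np.exp(-0.5 * np.sum(d**2, axis=-1))


def gp_log_marginal(X, y, mean, ls, amp, noise):
    """log N(y | mean, K + noise^2 I), computed through a Cholesky factor."""
    n = len(y)
    K = exp_quad(X, X, ls, amp) + noise**2 * np.eye(n)
    L = cholesky(K, lower=True)            # O(n^3): the expensive, BLAS/LAPACK-bound step
    r = y - mean
    alpha = cho_solve((L, True), r)        # K^{-1} r via two triangular solves
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (r @ alpha + logdet + n * np.log(2.0 * np.pi))
```

Timing `gp_log_marginal` on random data with n = 400 versus n = 100 gives a rough feel for the cubic scaling on your own machine.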

As Christian suggests, definitely use version 4 or 5 of PyMC rather than PyMC3.

Also, you may want to try using pymc.sampling_jax.sample_numpyro_nuts instead of pymc.sample – it will generally be faster.
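In recent PyMC 5 releases the NumPyro sampler can also be selected through the `nuts_sampler` argument of `pm.sample`. The sketch below assumes that argument is available in your installed version; `choose_nuts_sampler` and `sample_gp` are hypothetical helpers that fall back to PyMC's default sampler when JAX or NumPyro cannot be imported (for example, on a Windows setup where they failed to install).

```python
import importlib.util


def choose_nuts_sampler() -> str:
    """Return "numpyro" if both JAX and NumPyro are importable, else PyMC's default "pymc"."""
    have_jax = importlib.util.find_spec("jax") is not None
    have_numpyro = importlib.util.find_spec("numpyro") is not None
    return "numpyro" if (have_jax and have_numpyro) else "pymc"


def sample_gp(model, draws=1000, chains=2):
    """Sample `model` with NUTS, using the NumPyro backend when it is available."""
    import pymc as pm  # imported here so choose_nuts_sampler() works without PyMC installed

    with model:
        return pm.sample(draws, chains=chains, nuts_sampler=choose_nuts_sampler())
```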


Thank you. I'm having trouble installing numpyro for JAX sampling; I'm on Windows. Would I need to use Ubuntu?

Sampling is still taking an excessively long time, although the BLAS warning has gone now that I've created a PyMC 5 environment.

As @fonnesbeck said, GPs tend to be pretty slow. However, if you put together a self-contained example, someone else can run it and give you an estimate of how long you might expect sampling to take.
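A self-contained example does not need the real data. One option, sketched below, is to generate synthetic inputs of the same size and dimension with a fixed seed. `make_synthetic_kriging_data` and the particular signal and noise level are made up for illustration and are not the original data set.

```python
import numpy as np


def make_synthetic_kriging_data(n=400, seed=42):
    """Hypothetical stand-in data: n random 3-D locations with a smooth signal plus noise."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 10.0, size=(n, 3))
    signal = np.sin(X[:, 0]) + np.cos(0.5 * X[:, 1]) + 0.1 * X[:, 2]
    y = signal + rng.normal(0.0, 0.5, size=n)
    return X, y
```

Because the seed is fixed, anyone running the model on `X, y = make_synthetic_kriging_data()` sees the same inputs, so timings are comparable.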

Ah, yeah, JAX is still pretty experimental on Windows, so running it on Linux or macOS is preferred. You can set up a Linux environment on Windows with WSL (Windows Subsystem for Linux); that's typically what I do.

Another approach is to use variational inference (VI) rather than sampling. Calling pm.fit instead of pm.sample runs VI (ADVI by default), which is much faster, though in general its posterior estimate is not as accurate as the one you'd get from MCMC.