Multiprocessing GP extremely slow starting from certain size

Hi!
I've found that, starting from 96 points, multiprocess GP sampling suddenly becomes much slower than both single-process sampling and multiprocess sampling with 95 points. This is easy to reproduce:

import numpy as np
import pymc3 as pm

n = 96
X = np.random.rand(n, 1)
observed = np.random.rand(n)

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 1)
    gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    gp_f = gp.prior('gp_f', X=X)
    val_obs = pm.Normal('val_obs', mu=gp_f, sd=0.1, observed=observed)

    trace = pm.sample(njobs=2)

produces output like this:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [gp_f_rotated_]
Sampling 2 chains:  49%|████▉     | 981/2000 [01:14<01:32, 11.01draws/s]

By contrast, the same code with 95 points instead of 96 finishes in 10 seconds, and with njobs=1 instead of 2 it finishes in 18 seconds for 96 points.
I notice that with 96 or more points each process is already multithreaded, using 600% CPU on a 6-core (12 threads with HT) processor. With 95 or fewer points, each process uses only 100% CPU.
This is also certainly related to an older post of mine, "Sampling doesn't start when njobs > 1 for some models". Back then sampling didn't even start in this case; now it starts but runs extremely slowly.

@bwengals, do you have any idea what is happening here?

Which BLAS library is your computer using, OpenBLAS or MKL? These libraries choose how many threads to use depending on matrix size. You can avoid these problems by setting njobs=1 and chains=2, which tells pymc3 not to parallelize sampling over chains.

It uses MKL, as shown by np.distutils.__config__.show(). Yes, I understand that setting njobs=1 avoids the slowdown; I gave the corresponding timings in the first post. However, this means I always have to check whether the number of points is <= 95 and choose njobs accordingly to get the fastest sampling. And a slowdown of more than an order of magnitude from using njobs > 1 instead of 1 does seem very strange. The pymc3 default is to use multiple jobs, which in cases like this leads to very slow sampling.

It seems that most pymc3 models don't use much matrix math, so running multiple chains in parallel is the best default in most cases. For matrix-heavy models this choice is a little suboptimal, but still works well. A GP, however, is almost entirely Cholesky decompositions and inner products under the hood. Unless the data set is fairly small (it looks like MKL treats n <= 95 as small on your machine), BLAS already spreads a single process over all cores, so forcing pymc3 to sample multiple chains sequentially works better than the default setting. Since sampling multiple chains is important for diagnosing sample quality and convergence, it's good to have that be the default behavior.
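The rule of thumb above (parallel chains for small GPs, sequential chains once BLAS starts multithreading) can be captured in a small helper, so the njobs choice doesn't have to be made by hand each time. This is a minimal sketch under stated assumptions: choose_njobs is a hypothetical name, not part of pymc3, and the default threshold of 95 is only the value measured on the 6-core MKL machine in this thread; it will likely differ on other hardware or BLAS builds.

```python
def choose_njobs(n_points, threshold=95, chains=2):
    """Return how many chains to sample in parallel for a GP-style model.

    Hypothetical helper, not part of pymc3. `threshold` is assumed to be the
    largest problem size for which the BLAS library (MKL in this thread) kept
    each process single-threaded; 95 was measured on one machine only.
    """
    if n_points <= threshold:
        # Small matrices: each process uses about one core, so running the
        # chains in parallel makes use of the otherwise idle cores.
        return chains
    # Large matrices: BLAS already spreads one process over all cores, so
    # parallel chains would only compete with each other for the same cores.
    return 1


if __name__ == "__main__":
    for n in (95, 96):
        print(n, "points -> njobs =", choose_njobs(n))
```

With the model from the first post, the call would then be pm.sample(njobs=choose_njobs(n), chains=2). Passing chains=2 explicitly keeps two chains for convergence diagnostics either way; only whether they run in parallel or one after the other changes.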
