Parallelizing chains with custom likelihood on multiple cores

I think the original model is so small that the parallelization doesn't do anything useful.
If we try something a bit more costly, it does:

import time

import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model() as model:
    x = pm.Normal("x", shape=(50, 50))
    # A moderately expensive graph: solve a symmetric 50x50 linear system
    y = pt.linalg.solve(x + x.T, np.ones(50)).sum()
    pm.Normal("y", mu=y, observed=100.0)

times = []
for cores in [1, 4, 8]:
    start = time.time()
    with model:
        pm.sample(cores=cores, chains=cores, tune=100, draws=100, compute_convergence_checks=False)
    end = time.time()
    times.append(end - start)

This gives me [68.90292978286743, 92.38360643386841, 97.97894310951233], so sampling just one chain is a bit faster than sampling 4 or 8 chains in parallel, but not by much. (Probably because a single busy core produces less heat, so it can run at a higher clock, and because with several chains the total time is set by the slowest chain.)
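
To make that a bit more concrete, here is the same result expressed as throughput instead of wall time. This is just arithmetic on the numbers above (nothing PyMC-specific, and your numbers will differ): 8 chains take only about 40% longer than 1 chain while producing 8 times as many draws.

# Illustrative arithmetic on the wall times reported above.
wall_times = [68.9, 92.4, 98.0]   # seconds for cores = chains = 1, 4, 8
chain_counts = [1, 4, 8]
draws_per_chain = 200             # tune + draws

for chains, t in zip(chain_counts, wall_times):
    total_draws = chains * draws_per_chain
    print(f"{chains} chains: {total_draws / t:.1f} draws/s in total")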

Predicting speedups from parallelization is always a bit tricky. Maybe what helps is a bit of background about what's going on behind the scenes:

  1. When we call pm.sample, we first compile the PyTensor functions needed for the sampler. This can easily take a couple of seconds, but there is some caching going on, so usually it will be relatively fast. (See the sketch after this list for a way to time this step on its own.)
  2. We start one new process per chain (not per core!), serialize the sampler functions, send them to the new processes, and each of those deserializes them. The processes then just wait to be told to do something.
  3. The main process tells at most cores of the worker processes to generate a new draw, and waits for one of them to get back to it. The active worker processes all run at the same time, and call the logp_grad function a highly variable number of times (somewhere between 2 and 1000 times) to generate a draw, which they send to the main process. Then they wait again. When the main process gets one of the draws, it first tells that worker process to start computing the next one, then computes the deterministics and stores the draw in the trace object. It repeats this until all chains are finished (when one of the chains finishes, its slot is taken over by one of the chains - cores worker processes that are still waiting).
  4. We then compute convergence checks and return the trace object in the main thread.
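
To get a feel for how expensive steps 1 and 3 are for a particular model, you can time the compilation and a single evaluation yourself. This is a minimal sketch assuming the model defined above and a recent PyMC version; the sampler actually uses a fused logp-and-gradient function, so compile_logp/compile_dlogp are only a rough proxy for what the workers call:

import time

start = time.time()
logp_fn = model.compile_logp()    # step 1: compile the PyTensor functions
dlogp_fn = model.compile_dlogp()
print(f"compilation: {time.time() - start:.2f} s")

point = model.initial_point()     # dict of starting values for the free variables
start = time.time()
logp_fn(point)                    # step 3: what the workers call over and over
dlogp_fn(point)
print(f"one logp + gradient evaluation: {time.time() - start:.4f} s")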

If we want to find out how long the whole thing takes, it makes sense to think about where the bottleneck will be. It could actually be the first, second or last step, in which case all the parallelization won't help that much. If it is the 3rd step, there are two possibilities. If

n_logp_calcs_per_draw * time(logp_calc) < n_chains * (time(calc_deterministics) + time(store_in_trace)),

then the bottleneck will be the main process, which can't keep up with all the new draws from the worker processes. The worker processes will then just wait most of the time, and the parallelization won't help much. Otherwise, the main process will often be the one waiting for draws, the worker processes will actually run in parallel most of the time, and it will help.
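
As a back-of-the-envelope check, you can plug rough timings into that condition. The numbers below are made up for illustration (measure your own, e.g. with the snippet above):

# Hypothetical per-call timings in seconds; substitute measured values.
time_logp_calc = 2e-4            # one logp + gradient evaluation in a worker
n_logp_calcs_per_draw = 100      # NUTS needs somewhere between 2 and ~1000 per draw
time_calc_deterministics = 1e-4  # main process: compute deterministics for one draw
time_store_in_trace = 5e-5       # main process: store one draw in the trace
n_chains = 4                     # chains running at the same time (at most cores)

worker_time_per_draw = n_logp_calcs_per_draw * time_logp_calc
main_time_per_draw = n_chains * (time_calc_deterministics + time_store_in_trace)

if worker_time_per_draw < main_time_per_draw:
    print("Main process is the bottleneck: workers mostly wait, extra cores help little.")
else:
    print("Workers are the bottleneck: they run in parallel, extra cores help.")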

I think that pretty much explains what we are seeing?
