Does anyone know of any guides or docs for parallel sampling in the current version of PyMC (mp_ctx)?


I have been trying to speed up the sampling of my models and wanted to use the multiprocessing functionality in the pm.sample() function.
In the docs it says:

mp_ctx : multiprocessing.context.BaseContext
A multiprocessing context for parallel sampling. See multiprocessing documentation for details.
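The parameter takes a context object, not a start-method name. Below is a minimal, standard-library-only sketch (no PyMC involved) of how such objects are obtained. `available_contexts` is a hypothetical helper written for illustration.

```python
import multiprocessing as mp
from multiprocessing.context import BaseContext


def available_contexts() -> dict[str, BaseContext]:
    """Map every start method this platform supports to its context object."""
    return {name: mp.get_context(name) for name in mp.get_all_start_methods()}


if __name__ == "__main__":
    # With no argument, get_context() returns the context for the default start method
    print("default:", mp.get_context().get_start_method())
    for name, ctx in available_contexts().items():
        # e.g. mp.get_context("spawn") returns a SpawnContext, a BaseContext subclass
        print(name, type(ctx).__name__, isinstance(ctx, BaseContext))
```

Any of these objects can be passed as `mp_ctx`. The start method determines how worker processes are created (forking the parent, spawning a fresh interpreter, or forking from a server process). It does not determine how many are created, so switching between methods would not be expected to change sampling speed on its own.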

I have tried passing an mp context to pm.sample() and it seems to have no impact at all. It doesn't raise an error; it just runs at exactly the same speed. This is how I am running it:

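import multiprocessing as mp
import pymc as pm

# get_context() with no argument returns the context for the default start method
ctx_in_main = mp.get_context()
with self.model:  # runs inside a class method that stores the model on self
    self.idata = pm.sample(draws=draws, tune=tune, target_accept=target_accept,
                           cores=cores, mp_ctx=ctx_in_main)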

Am I missing something simple? I have tried this on an AWS Linux machine (with many free cores) and on my MacBook Pro M2, and it has no impact on either. I also tried all of the spawn, fork and forkserver start methods.

Any guidance at all appreciated!


What are you looking for? Usually the only thing you have to tweak is the number of cores.

Basically, I'm looking to run 4 chains on 12 cores rather than 4 chains on 4 cores. If you increase the number of cores, then you are just increasing the number of chains, no?

Yes, cores only determines how many chains are sampled in parallel. MCMC is a sequential algorithm: each draw depends on the previous one, so once the chains are parallelized there is no further gain to be had from parallelism at the level of MCMC itself. That said, within each MCMC step there can be (and often is) additional multiprocessing going on, but that happens independently of the cores argument. See here for relevant discussions.
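To make the point about sequential chains concrete, here is a toy sketch using only the standard library. It is not PyMC's implementation. A hand-written random-walk Metropolis sampler runs each chain in its own worker process via `concurrent.futures.ProcessPoolExecutor`, whose `mp_context` argument plays a role similar to `mp_ctx`. `metropolis_chain` and `run_chains` are hypothetical names for illustration.

```python
import math
import multiprocessing as mp
import random
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.context import BaseContext
from typing import Optional


def metropolis_chain(seed: int, draws: int = 1000, step: float = 1.0) -> list[float]:
    """Random-walk Metropolis targeting a standard normal.

    Each draw depends on the previous one, so a single chain cannot be
    split across cores: it is inherently sequential.
    """
    rng = random.Random(seed)
    x = 0.0
    samples = []
    for _ in range(draws):
        proposal = x + rng.gauss(0.0, step)
        # log acceptance ratio for N(0, 1): -(proposal^2 - x^2) / 2
        log_ratio = -0.5 * (proposal * proposal - x * x)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            x = proposal
        samples.append(x)
    return samples


def run_chains(chains: int, cores: int, ctx: Optional[BaseContext] = None,
               draws: int = 1000) -> list[list[float]]:
    """Run independent chains, with at most `cores` of them running at once.

    Parallelism exists only across chains: with 4 chains, at most 4 workers
    are ever busy, no matter how large `cores` is.
    """
    seeds = range(chains)
    if cores == 1:
        # Serial baseline: chains run one after another in this process
        return [metropolis_chain(s, draws) for s in seeds]
    workers = min(chains, cores)  # cores beyond the chain count would sit idle
    with ProcessPoolExecutor(max_workers=workers, mp_context=ctx) as pool:
        return list(pool.map(metropolis_chain, seeds, [draws] * chains))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # start method only changes how workers launch
    for cores in (1, 4, 12):
        start = time.perf_counter()
        run_chains(chains=4, cores=cores, ctx=ctx, draws=200_000)
        print(f"cores={cores:2d}: {time.perf_counter() - start:.2f}s")
```

Because each chain is seeded, serial and parallel runs return identical draws. `cores` only changes how many chains run at the same time. On a machine with enough free cores, the `cores=12` timing would be expected to roughly match `cores=4`, since only four chains exist to occupy workers.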
