Hi,
I have been trying to speed up the sampling of my models and wanted to use the multiprocessing functionality in the pm.sample() function.
In the docs it says:
mp_ctx
multiprocessing.context.BaseContent
A multiprocessing context for parallel sampling. See multiprocessing documentation for details.
I have tried passing an mp context to pm.sample() and it seems to have no impact at all. It doesn’t error; it just runs at exactly the same speed. This is how I am calling it:
import multiprocessing as mp
import pymc as pm

# Called with no argument, get_context() returns the platform's default context
ctx_in_main = mp.get_context()
with self.model:
    self.idata = pm.sample(draws=draws, tune=tune, target_accept=target_accept,
                           cores=cores, mp_ctx=ctx_in_main)
Am I missing something simple? I have tried this on an AWS Linux machine (with many free cores) and on my MacBook Pro M2, and it makes no difference on either. I also tried each of the spawn, fork and forkserver start methods.
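For reference, here is a minimal sketch of how each start method can be requested explicitly using only the standard library. The helper name available_contexts is hypothetical and made up for illustration; it is not part of PyMC. The sketch only builds the context objects and does not sample anything, so it says nothing about speed.

```python
import multiprocessing as mp


def available_contexts():
    """Return {start_method_name: context} for every start method
    this platform supports (hypothetical helper, illustration only)."""
    contexts = {}
    for name in mp.get_all_start_methods():
        # get_context(name) returns a context bound to that start method
        contexts[name] = mp.get_context(name)
    return contexts


contexts = available_contexts()
for name, ctx in contexts.items():
    # Show which concrete context class backs each start method
    print(name, type(ctx).__name__)
```

Any one of these objects (for example contexts["spawn"], where available) is what I would then pass as mp_ctx in the pm.sample() call above.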
Any guidance at all appreciated!
Thanks,
Luke