Is there a recommended way to harness server nodes with a high number of cores to speed up model fits in PyMC3?
In my specific scenario, I ideally want a small number (~4) of fairly long (~100k-sample) chains, so simply specifying a high number of cores in pm.sample is not appropriate.
One option is to run a large number of short chains in parallel and amalgamate them later, e.g. pm.sample(1000, …, chains=400, cores=400, …)
However, given that each of these mini-chains would require its own tuning, this seems rather wasteful…
Is it appropriate/possible to save the sampler state after tuning and reuse it for multiple small single-chain fits instead?
e.g., to do a pseudo 4-chain, 100k samples per chain fit:
run 4 tuning-only traces with no actual samples, e.g.

```python
tune_trace_1 = pm.sample(0, tune=50000, chains=1)
tune_trace_2 = pm.sample(0, tune=50000, chains=1)
tune_trace_3 = pm.sample(0, tune=50000, chains=1)
tune_trace_4 = pm.sample(0, tune=50000, chains=1)
```
then save/load the tuned states and apply them to:

```python
sample_trace_1 = pm.sample(1000, tune=<settings from tune_trace_1>, chains=100, cores=100)
sample_trace_2 = pm.sample(1000, tune=<settings from tune_trace_2>, chains=100, cores=100)
sample_trace_3 = pm.sample(1000, tune=<settings from tune_trace_3>, chains=100, cores=100)
sample_trace_4 = pm.sample(1000, tune=<settings from tune_trace_4>, chains=100, cores=100)
```
You would then amalgamate the 100 chains from each trace into four 100k-sample arrays, run your ESS, R-hat, etc., and ultimately be left with 400k-sample parameter posteriors?
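The amalgamation step itself is straightforward; here is a NumPy-only toy illustration (fake i.i.d. draws standing in for the real traces, and a basic hand-rolled Gelman-Rubin R-hat; in practice the diagnostics would come from arviz.rhat / arviz.ess):

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend each pm.sample call returned 100 chains of 1000 draws.
traces = [rng.normal(size=(100, 1000)) for _ in range(4)]

# Amalgamate: each trace's 100 chains become one 100k-draw pseudo-chain.
pseudo_chains = np.stack([t.reshape(-1) for t in traces])  # shape (4, 100000)

def rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for an (m, n) array of m chains."""
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    B = n * chain_means.var(ddof=1)          # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()    # mean within-chain variance
    var_hat = (n - 1) / n * W + B / n
    return float(np.sqrt(var_hat / W))

print(rhat(pseudo_chains))  # ~1.0 for these i.i.d. draws
```

One thing I can't judge is whether R-hat computed across such concatenated pseudo-chains is still meaningful, given that each one is really 100 independently-initialised short chains.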