I have a mixed-effects logistic regression model with many random terms, and GPUs with relatively small RAM. For even larger problems, this issue would matter even on GPUs with large RAM.
(1) The problem, in a nutshell, is that pymc.sampling_jax.sample_numpyro_nuts() does not offload accumulated samples from the GPU until sampling has finished, which makes the GPU's RAM the limiting factor for how many samples one can draw in a single run.
(2) Adding to this problem, when multiple chains are run on different GPUs, all samples need to fit into the RAM of one GPU.
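A back-of-envelope estimate shows why draws accumulating on one device become the bottleneck. This is a rough sketch, assuming the trace is stored as dense float64 arrays (8 bytes per value) and ignoring sampler statistics and NUTS working memory; the helper name `trace_bytes` and the model size are hypothetical.

```python
def trace_bytes(num_draws, num_chains, num_values, bytes_per_value=8):
    """Approximate size in bytes of a trace kept as one dense array.

    num_values is the total number of scalar parameters per draw (e.g. all
    random-effect levels combined); 8 bytes assumes float64, use 4 for float32.
    """
    return num_draws * num_chains * num_values * bytes_per_value


# Hypothetical model with 100,000 scalar parameters, 4 chains.
full_run = trace_bytes(num_draws=2000, num_chains=4, num_values=100_000)
one_chunk = trace_bytes(num_draws=500, num_chains=4, num_values=100_000)
print(full_run / 1e9, one_chunk / 1e9)  # 6.4 1.6  (GB)
```

Under point (2), where every chain's draws end up on one GPU, it is the total over all chains that has to fit; offloading every 500 draws would cap the resident trace at roughly the `one_chunk` figure.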
I have used a configuration that I believe should make keeping samples on the GPU pointless (in particular, postprocessing_backend="cpu"):
Why are old samples kept on the GPU? Is it uncommon to run out of GPU RAM (perhaps my models are unusually large)? How difficult would it be to implement periodic offloading of old samples from the GPU, so that GPU RAM is no longer the limit on the number of samples that can be collected in one run? For example, once 500 draws have been collected, offload them to CPU RAM and then store new draws in the same GPU buffer, which is no longer in use.
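The periodic offloading proposed above can be sketched in plain JAX. This is not how NumPyro or PyMC work internally: it uses a toy random-walk Metropolis sampler on a standard normal as a stand-in for NUTS and the real model, and `sample_chunk` / `sample_offloaded` are hypothetical names. Each jitted call produces only `chunk_size` draws on the device; they are copied to host memory with `jax.device_get` and the device array is dropped before the next chunk, so the device-side trace never exceeds one chunk.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def log_density(x):
    # Toy target (standard normal); a stand-in for the real model's log density.
    return -0.5 * jnp.sum(x ** 2)


@partial(jax.jit, static_argnames="chunk_size")
def sample_chunk(key, x0, chunk_size, step_size=0.5):
    """Run `chunk_size` random-walk Metropolis steps on the device."""
    def one_step(carry, step_key):
        x, logp = carry
        k_prop, k_acc = jax.random.split(step_key)
        prop = x + step_size * jax.random.normal(k_prop, x.shape)
        logp_prop = log_density(prop)
        accept = jnp.log(jax.random.uniform(k_acc)) < logp_prop - logp
        x = jnp.where(accept, prop, x)
        logp = jnp.where(accept, logp_prop, logp)
        return (x, logp), x

    keys = jax.random.split(key, chunk_size)
    (x_last, _), draws = jax.lax.scan(one_step, (x0, log_density(x0)), keys)
    return x_last, draws  # draws: shape (chunk_size, dim), still on the device


def sample_offloaded(key, x0, num_draws, chunk_size=500):
    """Draw `num_draws` samples, copying each chunk to host RAM as it completes."""
    host_chunks = []
    x = x0
    for start in range(0, num_draws, chunk_size):
        n = min(chunk_size, num_draws - start)
        key, chunk_key = jax.random.split(key)
        x, draws = sample_chunk(chunk_key, x, chunk_size=n)
        host_chunks.append(np.asarray(jax.device_get(draws)))  # device -> host copy
        del draws  # drop the device array so its memory can be freed or reused
    return np.concatenate(host_chunks, axis=0)


draws = sample_offloaded(jax.random.PRNGKey(0), jnp.zeros(3), num_draws=1200)
print(draws.shape)  # (1200, 3)
```

Only the sampler state (here, the last position) is carried between chunks, and a final shorter chunk triggers one extra compilation. For NUTS the carried state would also have to include the adapted step size and mass matrix, which NumPyro's MCMC.post_warmup_state holds.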
Is this something that could be done by changing the PyMC codebase, or should I redirect this feature request to the NumPyro community?
Otherwise, I think you’ll have more luck reaching out to the numpyro community. If you learn something we could be doing differently at the PyMC level let us know!
When I use more than one GPU, I use chain_method=“parallel”; with a single GPU, I use “sequential”. Would it make sense to use “sequential” when you have more than one GPU?
As for memory allocation, I use XLA_PYTHON_CLIENT_ALLOCATOR=platform, so only the GPU memory actually required is allocated, instead of JAX preallocating a large share up front.
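Both settings can be made explicit in the script. The allocator variable only takes effect if it is in the environment before JAX initializes its GPU backend (the JAX documentation also notes that this mode is slow). This is a minimal sketch: `choose_chain_method` is a hypothetical helper, not part of PyMC or NumPyro, that mirrors the rule described above, using NumPyro's chain_method values ("vectorized" is a third option). As far as I know, NumPyro itself warns and falls back to sequential chains when "parallel" is requested with fewer devices than chains.

```python
import os

# Must be set before JAX initializes its GPU backend (safest: before `import jax`).
# "platform" allocates on demand and frees memory that is no longer needed,
# instead of preallocating a large fraction of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax


def choose_chain_method(num_chains):
    """Hypothetical helper: one chain per device if there are enough devices."""
    if num_chains > 1 and jax.local_device_count() >= num_chains:
        return "parallel"  # chains mapped across devices
    return "sequential"  # chains run one after another on one device


print(jax.local_device_count(), choose_chain_method(4))
```

The returned string would be passed as chain_method to pymc.sampling_jax.sample_numpyro_nuts().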
Thanks for your reply; I’ll ask the NumPyro community!
I found this snippet, which indicates that samples can be fetched periodically, since a NumPyro MCMC object can be restarted from its latest state:
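```python
import gc
import numpy as onp
from jax import random
from numpyro.infer import MCMC, NUTS

# test1 (the model) and GPU_mem_state() (a GPU-memory report) are defined elsewhere
mcmc = MCMC(NUTS(test1), num_warmup=100, num_samples=100)
batches = []
for i in range(10):
    print("\n" + GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()  # dict: variable name -> draws of this run
    # copy this run's draws to host (CPU) memory as NumPy arrays
    batches.append({k: onp.atleast_1d(onp.asarray(v)) for k, v in samples.items()})
    del samples
    mcmc.post_warmup_state = mcmc.last_state  # next run resumes here, skipping warmup
    gc.collect()
```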
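Once each run's draws are on the host as a dict of NumPy arrays keyed by variable name (the structure mcmc.get_samples() returns, converted with onp.asarray), the runs can be joined along the draw axis. `concat_batches` is a hypothetical helper. With the default get_samples(), chains are flattened into the first axis, so axis=0 is the draw axis; with get_samples(group_by_chain=True) it would be axis=1.

```python
import numpy as np


def concat_batches(batches, axis=0):
    """Join a list of {variable_name: array} dicts along the draw axis."""
    names = batches[0].keys()
    return {name: np.concatenate([batch[name] for batch in batches], axis=axis) for name in names}


# Hypothetical example: three runs of 100 draws of a scalar "mu" and a 5-vector "b".
batches = [{"mu": np.zeros(100), "b": np.ones((100, 5))} for _ in range(3)]
full = concat_batches(batches)
print(full["mu"].shape, full["b"].shape)  # (300,) (300, 5)
```

For ArviZ, a single-chain result can be wrapped as arviz.from_dict(posterior={k: v[None] for k, v in full.items()}), since from_dict expects a leading chain dimension.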