I have a mixed-effects logistic regression model with many random terms, and GPUs with relatively small RAM. For even larger models, the problem described below would be relevant even on GPUs with large RAM.
(1) The problem, in a nutshell, is that
pymc.sampling_jax.sample_numpyro_nuts() does not offload accumulated samples from the GPU until sampling has finished, which makes GPU RAM the limiting factor for how many samples one can collect in a single run.
(2) Adding to this, when multiple chains are generated on different GPUs, all samples still need to fit into the RAM of a single GPU.
I have used a configuration that I believe should make keeping samples on the GPU unnecessary, in particular:
idata_kwargs=dict(log_likelihood=False), postprocessing_chunks=5, postprocessing_backend="cpu"
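For reference, this is roughly how I call the sampler. This is a sketch, not my exact code: `model`, `draws`, `chains` and `chain_method` are illustrative placeholders; the keyword arguments are the ones named above.

```python
import pymc as pm
import pymc.sampling_jax

with model:  # `model` is the mixed-effects logistic regression (defined elsewhere)
    idata = pymc.sampling_jax.sample_numpyro_nuts(
        draws=2000,
        chains=4,
        chain_method="parallel",                   # one chain per GPU
        postprocessing_backend="cpu",              # postprocess on the CPU...
        postprocessing_chunks=5,                   # ...in chunks, to limit peak memory
        idata_kwargs=dict(log_likelihood=False),   # skip log-likelihood computation
    )
```

Even with postprocessing moved off the GPU, the raw draws themselves still accumulate in GPU RAM until sampling finishes.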
Why are old samples kept on the GPU? Is it uncommon to run out of GPU RAM (perhaps my models are unusually large)? How difficult would it be to implement periodic off-loading of old samples from the GPU, so that GPU RAM is no longer the limiting factor for the number of samples that can be collected in one run? For example, once 500 draws have been collected, off-load them to CPU RAM, and then store subsequent draws in the same GPU buffer, which is no longer in use.
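To make the proposal concrete, here is a minimal sketch of the buffering scheme I have in mind. Plain Python lists stand in for GPU and CPU memory, and `draw_fn` is a hypothetical stand-in for producing one draw; in a real implementation the off-load step would be something like `jax.device_get` on the accumulated sample arrays.

```python
BUFFER_SIZE = 500  # draws kept on the GPU at any one time (illustrative)

def sample_with_offload(n_draws, draw_fn, buffer_size=BUFFER_SIZE):
    """Collect n_draws draws while keeping at most buffer_size on the 'GPU'."""
    device_buffer = []   # stands in for GPU RAM
    host_store = []      # stands in for CPU RAM
    for i in range(n_draws):
        device_buffer.append(draw_fn(i))
        if len(device_buffer) == buffer_size:
            host_store.extend(device_buffer)  # off-load full buffer to CPU RAM
            device_buffer.clear()             # reuse the same GPU buffer
    host_store.extend(device_buffer)          # flush any remaining draws
    return host_store
```

Peak "GPU" usage is then bounded by `buffer_size`, regardless of the total number of draws requested.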
Is this something that could be done by changing the pymc codebase, or should I redirect this feature request to the numpyro community?
Kind regards, Hans Ekbrand