Hello! I have a very large set of training data that I would like to use pymc to fit a model against. My issue: I do not have a sufficiently large GPU to load all of the data into memory (I am using blackjax NUTS on the GPU).
So: I would like to “stream” the data to the model using something similar to (or even exactly like) a PyTorch DataLoader.
Does anyone have an example of doing this? If not, is this even possible with the current pymc model instantiation?
Variational Inference supports minibatching, see here. I’m not sure how (or if) this works with compiling to JAX. My understanding is that you can’t do MCMC with minibatches because it would violate the detailed balance assumption used to ensure convergence to the true posterior.
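For reference, this is roughly what the minibatch ADVI pattern looks like in PyMC. The model below is just a stand-in (a toy linear regression with placeholder names), not your actual setup:

```python
import numpy as np
import pymc as pm

# Full dataset lives in host memory; only one random batch is used per gradient step.
X = np.random.randn(100_000, 5)
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + np.random.randn(100_000)

# Minibatch views that yield a fresh random batch each time they are evaluated
X_mb, y_mb = pm.Minibatch(X, y, batch_size=512)

with pm.Model() as model:
    beta = pm.Normal("beta", 0, 1, shape=X.shape[1])
    sigma = pm.HalfNormal("sigma", 1)
    mu = X_mb @ beta
    # total_size rescales the minibatch log-likelihood to the full dataset size
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y_mb, total_size=len(y))
    approx = pm.fit(n=10_000, method="advi")

idata = approx.sample(1_000)
```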
I need to look at the NUTS source. Presumably there is a for loop of some sort in which the algorithm iterates over the input data. I just don't want to have to hold that entire object in GPU memory all at once.
Since I would know at runtime how many points there are, I would just add the points in batches (one batch per iteration).
With NUTS you would need to iterate over the whole dataset in every gradient evaluation, and there are ~100s of gradient evaluations per step. Transferring data to the GPU at that frequency would probably grind sampling to a halt even if it were possible.
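To make that concrete, here is a minimal JAX sketch (a hypothetical toy model, not your actual one): the log-posterior closes over the full `X` and `y`, so every call to its gradient, i.e. every leapfrog step inside every NUTS draw, reads the entire dataset:

```python
import jax
import jax.numpy as jnp

# Toy data standing in for the full training set; NUTS needs all of it on the device.
X = jnp.ones((1_000_000, 5))
y = jnp.ones(1_000_000)

def logpost(beta):
    # Full-data log-likelihood: every term in the sum touches X and y.
    mu = X @ beta
    loglike = -0.5 * jnp.sum((y - mu) ** 2)
    logprior = -0.5 * jnp.sum(beta ** 2)
    return loglike + logprior

grad_logpost = jax.grad(logpost)

# The sampler evaluates something equivalent to this once per leapfrog step,
# i.e. potentially hundreds of times for a single posterior draw.
g = grad_logpost(jnp.zeros(5))
```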
Anyway, I am not familiar with a way that makes this possible, even slowly. You might need to check numpyro and see if it has any functionality like this. If it does, we can investigate whether it's possible to provide access to it from PyMC.