NUTS sampling with GPU acceleration in PyMC4

Hi there, I have set up a hierarchical Bayes model for choice data (on AWS SageMaker) and can draw samples with the NUTS sampler in PyMC4. Now I’m trying to run the sampling on a GPU; my SageMaker instance has one available. I tried to enable it through a .aesararc file by setting device=cuda, device=cuda0, and device=gpu, but none of these work; only device=cpu does. I’ve heard that GPU acceleration is supposed to be straightforward in PyMC4. Could someone please point me in the right direction? More details of my model and sampling can be found here:

I can also provide more information if need be.
Thanks in advance!

Have you followed the instructions on setting up your environment to use the GPU through the JAX backend?

You can check that JAX has found the GPU using:

In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]
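If you want a script to make the same check before it starts a long sampling run, a small guard like the one below works. It is a minimal sketch. jax_gpu_devices is a hypothetical helper name, and it assumes that jax.devices("gpu") raises RuntimeError when no GPU backend is present, which is what current JAX releases do.

```python
def jax_gpu_devices():
    """Return the GPU devices JAX can see, or [] if JAX or a GPU backend is missing."""
    try:
        import jax
    except ImportError:
        return []  # JAX (or jaxlib) is not installed in this environment
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is unavailable
        return []
```

For example, `assert jax_gpu_devices(), "JAX cannot see a GPU"` at the top of a sampling script fails fast instead of silently running on the CPU.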

Yes, I recently followed those instructions and confirmed, with the check you suggested, that JAX can access the GPU. Thank you!

Great, and you’re calling pymc.sampling_jax.sample_numpyro_nuts() to do your sampling? Sorry if these seem like basic questions; I’m just trying to make sure we’re on the same page.

No worries 🙂 Yes, that’s what I do.
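For readers following along, the sketch below shows roughly what such a call looks like. It assumes a PyMC 4.x install where sample_numpyro_nuts lives in pymc.sampling_jax, as in this thread, with NumPyro installed. The model is a generic hierarchical logistic model, not the poster’s choice model. simulate_choice_data and sample_on_gpu are hypothetical names, and the toy data are simulated. chain_method="vectorized" is a NumPyro option that runs all chains on one device, which suits a single-GPU instance.

```python
import random


def simulate_choice_data(n_groups, n_per_group, seed=0):
    """Hypothetical toy data: binary choices from n_groups respondents."""
    rng = random.Random(seed)
    group_idx, y = [], []
    for g in range(n_groups):
        p = rng.uniform(0.2, 0.8)  # each group gets its own choice probability
        for _ in range(n_per_group):
            group_idx.append(g)
            y.append(1 if rng.random() < p else 0)
    return y, group_idx


def sample_on_gpu(y, group_idx, n_groups, draws=1000, tune=1000, chains=4):
    """Fit a hierarchical logistic model with NumPyro's NUTS through JAX.

    Imports are deferred so this module loads even where PyMC is not installed.
    """
    import numpy as np
    import pymc as pm
    from pymc.sampling_jax import sample_numpyro_nuts

    y = np.asarray(y)
    group_idx = np.asarray(group_idx)
    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)  # population-level mean (logit scale)
        sigma = pm.HalfNormal("sigma", 1.0)  # between-group spread
        z = pm.Normal("z", 0.0, 1.0, shape=n_groups)
        theta = pm.Deterministic("theta", mu + sigma * z)  # non-centered group effects
        pm.Bernoulli("y", logit_p=theta[group_idx], observed=y)
        # JAX runs this on its default backend, i.e. the GPU when one is visible.
        # "vectorized" draws all chains on a single device instead of spreading them out.
        return sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, chain_method="vectorized"
        )
```

Usage would be along the lines of `y, g = simulate_choice_data(20, 50)` followed by `idata = sample_on_gpu(y, g, n_groups=20)`. No .aesararc device setting is involved. The GPU is picked up through JAX alone.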

And how have you confirmed that the GPU isn’t being utilized?

I have realized that the GPU is in fact being utilized; see my earlier post in this thread. Apparently, once JAX has GPU access, you’re done: PyMC4 uses the GPU through JAX when sampling with sample_numpyro_nuts().
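One way to answer "how have you confirmed it?" directly is to poll nvidia-smi from a second terminal or notebook while sampling runs. The sketch below assumes an NVIDIA GPU with nvidia-smi on the PATH. parse_gpu_stats and query_gpu_stats are hypothetical helpers. Keep in mind that JAX preallocates a large fraction of GPU memory as soon as it initializes, so memory.used jumps even before sampling starts. utilization.gpu during the run is the more telling number.

```python
import shutil
import subprocess


def parse_gpu_stats(csv_text):
    """Parse 'utilization, memory' rows from nvidia-smi CSV output (noheader, nounits)."""
    stats = []
    for line in csv_text.strip().splitlines():
        util, mem = (field.strip() for field in line.split(","))
        stats.append({
            # Some GPUs report "[N/A]" or "[Not Supported]"; map those to None
            "utilization_pct": int(util) if util.isdigit() else None,
            "memory_used_mib": int(mem) if mem.isdigit() else None,
        })
    return stats


def query_gpu_stats():
    """Return one dict per GPU, or None if nvidia-smi is not on PATH."""
    if shutil.which("nvidia-smi") is None:
        return None
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_stats(out)
```

Calling query_gpu_stats() every few seconds during sample_numpyro_nuts() should show utilization well above zero if the GPU is doing the work.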

Great! Glad you got it working. And yeah, when I was setting things up to use the GPU, I was preparing for a fight with the computer, too. It was far easier than I expected.

It was in fact a huge challenge for me.