Some questions on GPU based sampling

Hi all,

Up to now I have only been using CPU-based sampling and was somewhat hesitant to set up an environment that supports GPU-based sampling.

I found the article "Set up JAX sampling with GPUs in PyMC", which gives some hints about setting up such an environment. However, the article is not the newest and might not reflect current versions and developments.

I just ordered a new PC with an RTX 5070 Ti, which as far as I know requires CUDA 12.0.

I now have the following questions:

  1. Which sampling library is currently the “state of the art” for GPU-based sampling with PyMC? JAX? NumPyro? Others?
  2. Is CUDA 12.0 supported by these samplers, to your knowledge?
  3. If so, do you expect any problems with a GPU as new as the RTX 5070 Ti?

I plan to use GPU support in a Windows 11 environment; however, the JAX documentation suggests that this is still not possible natively, so I would have to use WSL as described in the article mentioned above. Is there any option to use a GPU-based sampler natively on Windows?

Thanks for sharing any experiences. I will report back with my own findings as soon as the PC arrives, probably in 2-3 weeks.

Best regards
Matthias

I started using GPU-based sampling with PyMC 5.20.1, even though it was possible before that as well. I have tried it on two different GPUs: an old one with 4 GB of VRAM and a GeForce RTX 3090. Furthermore, I am running Ubuntu 24 in WSL on Windows. What worked for me is this:

  1. Install numpyro
    pip install numpyro
  2. Install the JAX backend with CUDA support (yes, CUDA 12 is supported by numpyro and JAX); see the quick device check after this list
    pip install -U "jax[cuda12]"
  3. Specify the numpyro sampler in pm.sample
trace = pm.sample(
    draws=1000,
    chains=4,
    cores=1,                         # I set this to 1 because I have one GPU
    nuts_sampler="numpyro",
    nuts_sampler_kwargs={
        "chain_method": "vectorized" # this will run all chains concurrently;
                                     # "parallel" is useful if you have
                                     # more devices
    },
    initvals=...,
)
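
Not part of the steps above, but as a quick sanity check before sampling you can confirm that JAX actually sees the GPU (this just lists the devices JAX will use):

import jax

# should show a CUDA/GPU device if the CUDA build of JAX is installed correctly
print(jax.devices())

If this only shows a CPU device, JAX will fall back to sampling on the CPU.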

For some reason, I never managed to get the blackjax sampler to run :thinking:
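
For completeness, selecting it should just be a matter of swapping the sampler name (a sketch of the standard call; I could not confirm it works on my setup):

trace = pm.sample(
    draws=1000,
    nuts_sampler="blackjax",  # requires the blackjax package to be installed
)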

When comparing speed, sampling 1000 draws for the model I am working on took:

  1. 4 x 10 mins = 40 mins on the GeForce RTX 3090
  2. 4 x 1 hour = 4 hours on the older GPU
  3. 12 hours on my laptop with an Intel Core i7-1065G7 CPU (it is surely faster on better CPUs than this)

I am a bit biased, but I would recommend using nutpie instead of numpyro or blackjax. In my experience it outperforms both of those, on both CPU and GPU, but especially on the GPU.

You have to select the jax backend when compiling the model though:

import nutpie

compiled = nutpie.compile_pymc_model(model, backend="jax")
trace = nutpie.sample(compiled)

You can also specify whether PyTensor or JAX should take the necessary derivatives by passing gradient_backend="jax" or gradient_backend="pytensor".
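
For example (a minimal sketch; which gradient backend is faster depends on the model):

import nutpie

# compile with the JAX backend and let JAX compute the gradients
compiled = nutpie.compile_pymc_model(
    model,
    backend="jax",
    gradient_backend="jax",
)
trace = nutpie.sample(compiled)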


You can also achieve that through pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs=dict(backend="jax"))
