Set up environment for JAX sampling with GPU support in PyMC v4

Hi everyone,

This week, I have spent sometimes to re-install my dev environment, as I need to change to a new hard-drive. So I make a note on the steps I have done, hope that it may be useful for others, who want to run PyMC v4 with GPU support for Jax sampling. The step-by-step as follow:

1. Install Ubuntu 20.04.4 LTS (Focal Fossa)

The latest Ubuntu version is 22.04, but I’m a little bit conservative, so decided to install version 20.04. I download the 64-bit PC (AMD64) desktop image from here.

I made a Bootable USB using Rufus with the above ubuntu desktop .iso image. You can check this video How to Make Ubuntu 20.04 Bootable USB Drive. I assume that you have a NVIDIA GPU card on your local machine, and you know how to install ubuntu from a bootable USB. If not, you can just search it on Youtube.

2. Install NVIDIA Driver, CUDA 11.4, cuDNN v8.2.4

According to Jax’s guidelines, to install GPU support for Jax, first we need to install CUDA and CuDNN.

To do that, I follow the Installation of NVIDIA Drivers, CUDA and cuDNN from this guideline (Kudo the author Ashutosh Kumar for this).

One note is that we may not be able to find a specific version of NVIDIA Drivers on this step. Instead, we can go to this url: to download our specific driver version. For my case, I download the file at this link:

After successfully following these steps in the guideline, we can run nvidia-smi and nvcc --version commands to verify the installation. In my case, it will be somethings like this:

3. Install Jax with GPU supports

Following the Jax’s guidelines, after installing CUDA and CuDNN, we can using pip to install Jax with GPU support.

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f

Check if GPU device is available in Jax

We can then run Ipython or python and using these following commands to check.

In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]

That’s it. We have successfully installed Jax with GPU support. Now, we can run Jax-based sampling pm.sampling_jax.sample_numpyro_nuts(...) in PyMC v4 with the GPU capability.

Feel free to ask any questions here if you face any difficulty in these above steps.



This is great, can you add this to the pymc wiki in a similar style to the installation instructions?

1 Like

Yes sure, I will do that soon.