Set up environment for JAX sampling with GPU support in PyMC v4

Hi everyone,

This week I spent some time re-installing my dev environment, as I had to move to a new hard drive. I took notes on the steps I followed, hoping they may be useful for others who want to run PyMC v4 with GPU support for JAX sampling. The steps are as follows:

1. Install Ubuntu 20.04.4 LTS (Focal Fossa)

The latest Ubuntu version is 22.04, but I’m a bit conservative, so I decided to install 20.04. I downloaded the 64-bit PC (AMD64) desktop image from here.

I made a bootable USB drive using Rufus with the Ubuntu desktop .iso image above. You can check the video How to Make Ubuntu 20.04 Bootable USB Drive. I assume that you have an NVIDIA GPU in your local machine and that you know how to install Ubuntu from a bootable USB drive. If not, you can search for it on YouTube.

2. Install NVIDIA Driver, CUDA 11.4, cuDNN v8.2.4

According to JAX’s guidelines, to install GPU support for JAX we first need to install CUDA and cuDNN.

To do that, I followed the section “Installation of NVIDIA Drivers, CUDA and cuDNN” from this guideline (kudos to the author, Ashutosh Kumar).

One note: at this step we may not be able to find the specific NVIDIA driver version we need. Instead, we can download that specific driver version from this URL:

In my case, I downloaded the file at this link:

After successfully following the steps in the guideline, we can run the nvidia-smi and nvcc --version commands to verify the installation. In my case, the output looks something like this:
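The same two checks can also be scripted. The sketch below is a hypothetical helper (not part of the original guide) that calls nvidia-smi with its --query-gpu/--format options and pulls the CUDA release number out of the nvcc --version text; it returns empty results instead of failing when a tool is missing. Note that nvcc may be installed but not on PATH (for example if the CUDA bin directory was never added), in which case cuda_release comes back as None.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Return the CUDA release (e.g. '11.4') found in `nvcc --version` output, or None."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def run_tool(args):
    """Run a command-line tool and return its stdout, or None if it is not on PATH or fails."""
    if shutil.which(args[0]) is None:
        return None
    try:
        result = subprocess.run(args, capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, OSError):
        return None
    return result.stdout


def check_cuda_install():
    """Collect GPU name/driver rows from nvidia-smi and the CUDA release from nvcc."""
    smi = run_tool(["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"])
    nvcc = run_tool(["nvcc", "--version"])
    return {
        "gpus": [line.strip() for line in smi.splitlines() if line.strip()] if smi else [],
        "cuda_release": parse_nvcc_release(nvcc) if nvcc else None,
    }


if __name__ == "__main__":
    print(check_cuda_install())
```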

3. Install JAX with GPU support

Following JAX’s guidelines, after installing CUDA and cuDNN, we can use pip to install JAX with GPU support.

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f

Check if the GPU device is available in JAX

We can then start IPython or Python and use the following commands to check:

In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]

That’s it. We have successfully installed JAX with GPU support. Now we can run JAX-based sampling with pm.sampling_jax.sample_numpyro_nuts(...) in PyMC v4 on the GPU.
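As a minimal sketch of that last step, assuming PyMC v4 with its pymc.sampling_jax module and a working JAX/NumPyro install, here is the Beta-Bernoulli model used later in this thread, sampled with sample_numpyro_nuts. The wrapper name run_jax_sampler is hypothetical; it returns None instead of failing when PyMC or JAX cannot be imported.

```python
import numpy as np


def run_jax_sampler(data, draws=500, tune=500, chains=1):
    """Fit a Beta-Bernoulli model with PyMC v4's JAX/NumPyro NUTS sampler.

    Returns an ArviZ InferenceData object, or None if PyMC/JAX are not importable.
    """
    try:
        import pymc as pm
        import pymc.sampling_jax  # noqa: F401  (makes pm.sampling_jax available)
    except ImportError:
        return None

    with pm.Model():
        theta = pm.Beta("theta", alpha=1.0, beta=1.0)  # flat prior on the success probability
        pm.Bernoulli("y", p=theta, observed=data)      # likelihood of the 0/1 observations
        return pm.sampling_jax.sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, random_seed=123
        )
```

With one chain per GPU there is no question of parallel vs. sequential chains; running several chains on a single GPU is discussed further down the thread.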

Feel free to ask any questions here if you run into difficulties with the steps above.



This is great. Can you add this to the PyMC wiki, in a similar style to the installation instructions?


Yes sure, I will do that soon.

Hi, I tried to follow the above instructions and ran into many issues; however, as a general framework the instructions were very helpful. This is how I finally got things to work:

Graphics driver:
I installed the driver suggested by Ubuntu itself (via “Software and Updates”).

I used instructions in

to install cuda-11.4

I downloaded the tgz archive from the cuDNN website and followed the instructions in

However, when executing the commands under “Copy the following files into the CUDA toolkit directory.”, I changed every cuda in the paths to cuda-11.4.
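After copying, one way to confirm which cuDNN version actually ended up in the toolkit directory is to read the version macros from cudnn_version.h. This is a minimal sketch; it assumes cuDNN 8 (which keeps CUDNN_MAJOR/CUDNN_MINOR/CUDNN_PATCHLEVEL in cudnn_version.h) and the /usr/local/cuda-11.4 location implied by the renaming above, so adjust the path if your toolkit lives elsewhere.

```python
import re
from pathlib import Path

# Assumed location: the CUDA 11.4 toolkit directory the cuDNN headers were copied into.
CUDNN_HEADER = Path("/usr/local/cuda-11.4/include/cudnn_version.h")


def parse_cudnn_version(header_text):
    """Build 'MAJOR.MINOR.PATCHLEVEL' from the #define lines of cudnn_version.h, or None."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            return None
        parts.append(match.group(1))
    return ".".join(parts)


def installed_cudnn_version(header=CUDNN_HEADER):
    """Return the cuDNN version recorded in the header, or None if the file is missing."""
    if not header.is_file():
        return None
    return parse_cudnn_version(header.read_text())


if __name__ == "__main__":
    print("cuDNN:", installed_cudnn_version())
```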

I followed the instructions in

to install compatible versions of these libraries

Installed it from:

I followed:

to install PyMC v4, but I stopped short of running the last command (i.e. I didn’t run “pip install blackjax”) because I thought it would probably interfere with my jax/jaxlib installations.
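Before and after installing an extra package such as blackjax, it can help to record which versions of jax, jaxlib and the rest of the PyMC stack are installed, so that an unintended upgrade or downgrade of jaxlib is easy to spot. A small sketch using only the standard library (importlib.metadata); the package list is an assumption and can be edited freely.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions need to agree for JAX sampling in PyMC v4 (editable list).
PACKAGES = ("jax", "jaxlib", "numpyro", "aesara", "pymc", "blackjax")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version string, or None if it is absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:10s} {ver or 'not installed'}")
```

Saving the output before running pip and comparing it afterwards shows immediately whether jax or jaxlib were touched.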

After all of this, when I open a Jupyter notebook in my pymc_env virtual environment and run:

import jax
import jaxlib

I get:

And also from:

I get:

So it seems like JAX has access to the GPU, but when I sample via pm.sampling_jax.sample_numpyro_nuts it’s pretty slow, and I don’t think it’s running on the GPU.

Any advice will be appreciated.


Hi @payamphysics

If you put your sampling_jax code in a file named, you can run this command in a bash shell to check whether it can use the GPU:
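The thread selects the platform with the JAX_PLATFORM_NAME environment variable on the command line. The same variable can also be set from inside the script, as long as this happens before JAX is imported. A minimal sketch under that assumption (newer JAX releases may prefer JAX_PLATFORMS instead); current_backend is a hypothetical helper that returns None when JAX is missing or the requested platform cannot be initialised.

```python
import os

# Must run before JAX is imported: JAX reads the variable when it sets up its backends.
# setdefault keeps any value already given on the command line.
os.environ.setdefault("JAX_PLATFORM_NAME", "gpu")


def current_backend():
    """Return JAX's default backend name ('gpu', 'cpu', ...), or None if unavailable."""
    try:
        import jax
    except ImportError:
        return None
    try:
        return jax.default_backend()
    except RuntimeError:
        # JAX raises RuntimeError when the requested platform has no usable device.
        return None


if __name__ == "__main__":
    print("JAX backend:", current_backend())
```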



Hi @DanhPhan,

Thank you for your suggestion. I did what you said, but in the terminal output I don’t see an explicit confirmation that the GPU is being used. Does the mere fact that the device is specified as GPU and the code runs fine mean that the GPU is being used?
Here’s what I see after running the following command:

$ JAX_PLATFORM_NAME=gpu python

/home/yixian/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/jax/ UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see
  warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
/home/yixian/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/ UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")
Compilation time =  0:00:04.611605
sample: 100%|█| 100/100 [00:36<00:00,  2.74it/s, 255 steps of size 2.39e-02. acc
Sampling time =  0:00:44.249387
Transforming variables...
Transformation time =  0:00:00.298734
Computing Log Likelihood...
Log Likelihood time =  0:00:00.844047
t2-t1 is 0.8344 minutes

Hi, have you also run it with the CPU option, like $ JAX_PLATFORM_NAME=cpu python, and compared the times?
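That comparison can be automated by launching the same script twice in fresh interpreters, once per platform, so that each run initialises JAX from scratch. This is a hypothetical sketch: the script path is whatever you pass on the command line, and the measured wall time includes interpreter start-up, imports and JIT compilation, not just sampling.

```python
import os
import subprocess
import sys
import time


def time_script(script_args, platform):
    """Run Python with `script_args` and JAX_PLATFORM_NAME=platform; return (seconds, exit code)."""
    env = dict(os.environ, JAX_PLATFORM_NAME=platform)
    start = time.perf_counter()
    completed = subprocess.run([sys.executable, *script_args], env=env)
    return time.perf_counter() - start, completed.returncode


def compare_platforms(script_args, platforms=("cpu", "gpu")):
    """Time the same script once per platform, each in its own process."""
    return {p: time_script(script_args, p) for p in platforms}


# Usage: python compare_platforms.py your_sampling_script.py  (file names are hypothetical)
if __name__ == "__main__" and len(sys.argv) > 1:
    for platform, (seconds, code) in compare_platforms(sys.argv[1:]).items():
        print(f"{platform}: {seconds:.1f} s (exit code {code})")
```

A non-zero exit code for the gpu run usually means JAX could not find a GPU backend at all, which is itself a useful diagnostic.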

Also, how many GPUs do you have? In my case I only have 1 GPU, so when I set 2 (or more) chains or cores when sampling with JAX, it shows a message that I only have 1 GPU, so the chains are sampled sequentially, not in parallel. That’s how I know it runs on the GPU :slight_smile:

But you’re right that there is no explicit confirmation that the GPU is being used.


Can you open a terminal and run the command nvidia-smi while the sampler is running? You should see the GPU utilization (%) and memory usage that way.
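For a continuously refreshing view, nvidia-smi -l 1 repeats the report every second. If you would rather log the numbers from Python while the sampler runs in another process, here is a small sketch built on nvidia-smi’s --query-gpu option; sample_gpu_usage and watch_gpu are hypothetical helper names, and an empty list means nvidia-smi was not found or failed.

```python
import shutil
import subprocess
import time


def parse_gpu_usage(csv_text):
    """Parse 'utilization.gpu, memory.used' rows (csv, noheader, nounits) into (percent, MiB)."""
    rows = []
    for line in csv_text.strip().splitlines():
        util, mem = (field.strip() for field in line.split(","))
        rows.append((int(util), int(mem)))
    return rows


def sample_gpu_usage():
    """Take one reading per GPU; returns [] if nvidia-smi is unavailable or fails."""
    if shutil.which("nvidia-smi") is None:
        return []
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (subprocess.CalledProcessError, OSError):
        return []
    return parse_gpu_usage(out)


def watch_gpu(n_samples=10, interval=1.0):
    """Print a reading every `interval` seconds, e.g. while sampling runs elsewhere."""
    for _ in range(n_samples):
        print(sample_gpu_usage())
        time.sleep(interval)
```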


Yes, I was able to see the GPU usage that way. Thank you for the suggestion!


Hi everyone,
Sorry to reopen the thread, but I believe this is the right place to ask instead of creating a new thread.

So… I followed all the instructions here to start using the GPU. Although nvidia-smi shows the GPU, and so does JAX, when I run a very simple model (from Bayesian Analysis with Python, chapter 2) the sampler still seems to use the CPU: CPU load rises to 50% while GPU usage stays flat at around 1%, just like before installing JAX.

BTW: interestingly, even before installing JAX and GPU support, running the same script in Ubuntu under WSL2 was 10x faster than in Windows on the same computer! I am very curious to see how fast it will be once the GPU is used for sampling.

Thank you very much for any help

Hi, it’s me again…

I made some progress: I had forgotten to use the JAX sampler. Now it uses the GPU… but it was much slower than the CPU.

This is completely new territory for me:


import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np
import pandas as pd
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az
import jax
import jaxlib


WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
/home/slepetys/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/ UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")






trials = 4
theta_real = 0.35  # unknown value in a real experiment
data = stats.bernoulli.rvs(p=theta_real, size=trials)

with pm.Model() as our_first_model:
    # a priori
    θ = pm.Beta('θ', alpha=1., beta=1.)
    # likelihood
    y = pm.Bernoulli('y', p=θ, observed=data)
    #idata = pm.sample(tune=10000, draws=50000, random_seed=123, return_inferencedata=True)
    idata = pm.sampling_jax.sample_numpyro_nuts(tune=10000, draws=1000, chains=10, random_seed=123)


/home/slepetys/anaconda3/envs/pymc_env/lib/python3.10/site-packages/tqdm/ TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See
  from .autonotebook import tqdm as notebook_tqdm
/home/slepetys/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/ UserWarning: There are not enough devices to run parallel chains: expected 10 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(10)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  pmap_numpyro = MCMC(

… and sampling took minutes instead of seconds:

I am quite sure that I am doing something very silly here.

Thank you for your patience

It’s running the chains sequentially instead of in parallel. If your model runs as fast on the CPU as on the GPU, you would see a slowdown proportional to the number of chains. You can try passing chain_method="vectorized" to sample_numpyro_nuts.
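The choice can be made automatically from the number of visible devices. In NumPyro, "parallel" maps one chain per device and, as the warning above shows, falls back to drawing chains sequentially when there are fewer devices than chains; "vectorized" batches all chains on a single device. A minimal sketch, assuming PyMC v4’s sample_numpyro_nuts with its chain_method and model arguments; choose_chain_method and sample_all_chains are hypothetical names.

```python
def choose_chain_method(chains, n_devices):
    """'parallel' needs one device per chain; otherwise batch all chains on one device."""
    return "parallel" if n_devices >= chains else "vectorized"


def sample_all_chains(model, chains=4, draws=1000, tune=1000):
    """Hypothetical wrapper around sample_numpyro_nuts that picks chain_method itself."""
    import jax
    import pymc.sampling_jax

    method = choose_chain_method(chains, jax.local_device_count())
    return pymc.sampling_jax.sample_numpyro_nuts(
        draws=draws, tune=tune, chains=chains, chain_method=method, model=model
    )
```

On a CPU-only run, the warning quoted above also suggests calling numpyro.set_host_device_count(n) at the very start of the program, which exposes n CPU devices so that "parallel" can be used instead.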


Thank you Ricardo,
I updated it, and it is running in parallel now.

In a very quick set of experiments, I found that the gains from using the JAX sampler with the GPU appear only once the dataset crosses some size threshold. The PyMC sampler using the CPU cores is much faster for small datasets.

A small but useful speed-up came from moving from the Windows environment to Ubuntu under WSL2. After some tests, I concluded that the gain comes from the compilation overhead under Windows, which is almost zero under Linux/Ubuntu.

All the best!


Yes, the GPU is not always the best backend. It depends on the data size and on the operations required to compute the joint logp/dlogp.
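One way to see that effect in isolation is to time only the quantity NUTS evaluates at every step, the log-likelihood and its gradient, for growing data sizes on whatever backend JAX selected. This is a sketch, not how PyMC itself builds the logp: it hand-writes a Bernoulli log-likelihood in terms of the logit of theta, and excludes the first (compiling) call from the timing. Running it once with JAX_PLATFORM_NAME=cpu and once with gpu shows where the crossover lies for this toy model.

```python
import time

import numpy as np


def time_logp_grad(sizes, repeats=20, seed=0):
    """Seconds per call of a jitted Bernoulli logp + gradient, per data size; None without JAX."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    def logp(logit_theta, y):
        # log(theta) = log_sigmoid(logit), log(1 - theta) = log_sigmoid(-logit)
        return jnp.sum(y * jax.nn.log_sigmoid(logit_theta)
                       + (1 - y) * jax.nn.log_sigmoid(-logit_theta))

    logp_and_grad = jax.jit(jax.value_and_grad(logp))
    rng = np.random.default_rng(seed)
    timings = {}
    for n in sizes:
        y = jnp.asarray(rng.binomial(1, 0.35, size=n), dtype=jnp.float32)
        x = jnp.float32(0.0)
        logp_and_grad(x, y)[0].block_until_ready()  # first call compiles; not timed
        start = time.perf_counter()
        for _ in range(repeats):
            _value, grad = logp_and_grad(x, y)
        grad.block_until_ready()  # JAX dispatch is asynchronous; wait before stopping the clock
        timings[n] = (time.perf_counter() - start) / repeats
    return timings
```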

Hi @Roberto_Slepetys, could you elaborate on the dataset size threshold?

It would be interesting to run a simulation and plot a chart showing the performance of CPUs vs. GPUs, with and without JAX sampling, over a range of dataset sizes. Thanks

Hi DanhPhan,

I just ran a very simple simulation:

  1. First, I generated artificial data, as in the code above:
trials = 4
theta_real = 0.35  # unknown value in a real experiment
data = stats.bernoulli.rvs(p=theta_real, size=trials)
  2. Then I ran the model using either JAX[GPU] or the PyMC sampler on the CPU cores.
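The two steps above can be turned into a small benchmark over several dataset sizes, which is also what would be needed for the CPU/GPU chart asked about earlier. A sketch under stated assumptions: benchmark_samplers, _build_model, pymc_cpu and pymc_jax are hypothetical names, the runners assume PyMC v4 (pm.sample and pymc.sampling_jax.sample_numpyro_nuts), and every timing includes model compilation, which is exactly the overhead reported to differ between Windows and WSL2.

```python
import time

import numpy as np


def benchmark_samplers(samplers, sizes, theta_real=0.35, seed=123):
    """Time every sampler on artificial Bernoulli data of every size.

    `samplers` maps a label to a callable taking a 0/1 data array.
    Returns {(label, size): wall-clock seconds}, compilation included.
    """
    rng = np.random.default_rng(seed)
    results = {}
    for n in sizes:
        # Same kind of data as stats.bernoulli.rvs(p=theta_real, size=n).
        data = rng.binomial(1, theta_real, size=n)
        for label, run in samplers.items():
            start = time.perf_counter()
            run(data)
            results[(label, n)] = time.perf_counter() - start
    return results


def _build_model(data):
    """Hypothetical helper: the Beta-Bernoulli model from this thread."""
    import pymc as pm

    model = pm.Model()
    with model:
        theta = pm.Beta("theta", alpha=1.0, beta=1.0)
        pm.Bernoulli("y", p=theta, observed=data)
    return model


def pymc_cpu(data):
    """Hypothetical runner: PyMC's default NUTS sampler on the CPU cores."""
    import pymc as pm

    with _build_model(data):
        return pm.sample(progressbar=False)


def pymc_jax(data):
    """Hypothetical runner: NumPyro NUTS through JAX, on whichever backend JAX selects."""
    import pymc.sampling_jax

    with _build_model(data):
        return pymc.sampling_jax.sample_numpyro_nuts(progress_bar=False, chain_method="vectorized")
```

For example, benchmark_samplers({"cpu": pymc_cpu, "jax": pymc_jax}, [4, 1_000, 100_000]) returns one timing per (sampler, size) pair, which can then be plotted as time vs. dataset size with one line per sampler. The sizes beyond the thread’s trials = 4 are arbitrary.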

The comparison was not different from what Martin Ingram published in the post MCMC for big datasets: faster sampling with JAX and the GPU (PyMC Labs), which I believe is the same comparison you would like to evaluate.

My new finding was that when you run the models on CPU cores, execution is much faster under Ubuntu, even under WSL2, than under Windows, because of the time required to instantiate the compiler. On average, compilation took an additional 18 seconds under Windows 10 compared with WSL2/Ubuntu, and from the JupyterLab perspective there was no difference between the OS environments. Also, WSL2 allowed me to install JAX with NVIDIA GPU support, which I was not able to do under MS Windows.
