My installation protocol for GPU-based sampling on Windows (RTX 50*)

Dear Bayesians,

it might be of some help to some of you, today or later, to have an up-to-date installation protocol for GPU-based sampling on Windows. I have seen some pip-based protocols here in the forum; however, I found that a purely conda-based protocol also seems to work.

First of all, make sure that the NVIDIA / CUDA drivers are installed properly by running this in PowerShell:

$ nvidia-smi
Sun Mar 16 09:18:12 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 572.70                 Driver Version: 572.70         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5070 Ti   WDDM  |   00000000:01:00.0  On |                  N/A |
|  0%   53C    P5             21W /  300W |   15663MiB /  16303MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

This was all pre-installed by my PC vendor, therefore I did not have to do anything.

Next, it is unfortunately still not possible to install the relevant libraries directly in a native Windows Python environment. This means you need to use a WSL / Ubuntu-based Python.

To install WSL, do this in a PowerShell:

wsl --install

After installation, check which Ubuntu distributions are available:

wsl --list --online

I decided on:

wsl --install -d Ubuntu-22.04

After installation, open the “Terminal” app and create a new tab running your Ubuntu. The next step is to set up a conda-based Python installation inside this Ubuntu. I did this with the following steps:

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc

Make conda-forge available:

conda config --add channels conda-forge
conda config --set channel_priority strict

You can now create a new conda environment for your PyMC stuff. As I had some trouble with Python 3.12, I ended up using a 3.11 environment:

conda create -n jax-env python=3.11
conda activate jax-env

Install mamba to get things done faster:

conda install -n base -c conda-forge mamba

Now install the relevant libraries:

mamba install -c conda-forge pymc bambi arviz nutpie blackjax numpyro

And some additional packages:

mamba install -c conda-forge matplotlib seaborn scikit-learn gputil jupyter

The environment is now created. To use it in a Jupyter notebook in a Windows IDE, you have to configure it as your runtime. I am using DataSpell from JetBrains, and I guess the configuration is similar in IntelliJ or PyCharm; I do not know how this would work in VS Code.

In your project configuration, you add a new interpreter and choose “On WSL”:

In the following dialog, you choose your installed Ubuntu:

In the following dialog, you choose “Conda environment” and you should find your configured environment in the dropdown:

After creation, you should find this interpreter in the list and can choose it for your project.

To use this interpreter for a Jupyter notebook, you can configure it as a “Managed Server”:

Now check some versions:
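
For example, a minimal check from Python could look like this (using only the packages installed above; jax.default_backend() should report "gpu" when the CUDA device is picked up):

from importlib.metadata import version
import jax

# Print the versions of the sampling stack installed above
for pkg in ["pymc", "bambi", "arviz", "nutpie", "blackjax", "numpyro", "jax"]:
    print(pkg, version(pkg))

# "gpu" means JAX found the CUDA device exposed through WSL
print("jax backend:", jax.default_backend())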

And here are some sampling times with a very simple model. I assume that the GPU is not the right choice for a model with only one observed variable, because it only pays off with many more dimensions. Or did I do something wrong with the configuration?

import numpy as np
import pymc as pm

true_mean_of_generating_process = 5
true_sigma_of_generating_process = 2

observed_data = np.random.normal(loc=true_mean_of_generating_process,
                                 scale=true_sigma_of_generating_process,
                                 size=100)

with pm.Model() as model_mean:
    mu = pm.Normal("mu", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=10)

    likelihood = pm.Normal("likelihood",
                           mu=mu,
                           sigma=sigma,
                           observed=observed_data)

    # CPU
    # runtime: 3.28 seconds (AMD Ryzen 9 9900X)
    # idata_mean_cpu = pm.sample(
    #     draws=8000,
    #     tune=1000,
    #     chains=4,
    #     progressbar=False)

    # GPU with numpyro
    # runtime: 1m 49 seconds
    # idata_mean_numpyro = pm.sample(
    #     draws=8000,
    #     tune=1000,
    #     chains=4,
    #     progressbar=False,
    #     nuts_sampler="numpyro")

    # GPU with blackjax
    # needs the argument "chain_method" : "vectorized" to run without error
    # runtime: 20 seconds
    # idata_mean_blackjax = pm.sample(
    #     draws=8000,
    #     tune=1000,
    #     chains=4,
    #     progressbar=False,
    #     nuts_sampler="blackjax",
    #     nuts_sampler_kwargs={"chain_method" : "vectorized"})

    # GPU with nutpie
    # runtime: 1m 31 seconds
    # idata_mean_nutpie = pm.sample(
    #     draws=8000,
    #     tune=1000,
    #     chains=4,
    #     progressbar=False,
    #     nuts_sampler="nutpie",
    #     nuts_sampler_kwargs={"backend" : "jax"})
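
One way to get comparable wall-clock numbers is, for example (shown here for the blackjax variant; the measurement approach is just a suggestion, not necessarily how the times above were taken):

import time

start = time.perf_counter()
with model_mean:
    idata_mean_blackjax = pm.sample(draws=8000,
                                    tune=1000,
                                    chains=4,
                                    progressbar=False,
                                    nuts_sampler="blackjax",
                                    nuts_sampler_kwargs={"chain_method": "vectorized"})
print(f"sampling took {time.perf_counter() - start:.1f} s")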

Best regards
Matthias

You should check that the GPU is detected and configured correctly by importing jax directly, like:

import jax
jax.devices() # [CudaDevice(id=0)]

# Confirm you can put some data on the GPU
jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0])

After you import jax, you should also see the memory usage jump to 80-90% in nvidia-smi, because XLA preallocates a large chunk of GPU memory by default.
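
If that preallocation gets in the way (for example because the GPU also drives your desktop), you can tune it before the first jax import. The two environment variables below are standard JAX/XLA settings; the concrete values are only examples:

import os

# Must be set before jax is imported for the first time
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # disable the big up-front allocation
# or cap the preallocated fraction of GPU memory (JAX defaults to 75%)
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax
print(jax.devices())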

As for sampling, GPUs are not magic speedup boxes. They are very good at doing the same computations over and over, especially when those computations can be done in parallel. NUTS is not such a computation, because building Hamiltonian trajectories is inherently sequential with embedded control flow (leapfrog integrator steps, with checks for a U-turn).

For a model to get a speedup doing NUTS on GPU, you need the logp/gradient computations (done as part of each leapfrog step) to be expensive and amenable to GPU acceleration (e.g. involving large matrix multiplications). Otherwise, you’ll be spending most of your time in (sequential) leapfrog steps, where the GPU isn’t going to help much.
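
To make that concrete, here is an illustrative sketch (not a benchmark; the sizes and priors are made up) of a model whose logp and gradient are dominated by a large matrix multiplication, which is the kind of workload where the GPU can actually help inside each leapfrog step:

import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
n_obs, n_pred = 50_000, 500                      # illustrative sizes only
X = rng.normal(size=(n_obs, n_pred))
y = X @ rng.normal(size=n_pred) + rng.normal(size=n_obs)

with pm.Model() as big_regression:
    beta = pm.Normal("beta", 0, 1, shape=n_pred)
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.math.dot(X, beta)                    # large matmul in every logp/gradient evaluation
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)

    # Same call pattern as above; numpyro/blackjax run this through JAX on the GPU
    idata = pm.sample(nuts_sampler="numpyro", progressbar=False)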

To make good use of GPUs, people are turning to other algorithms, such as VI, pathfinder, or normalizing flows. Experts like @aseyboldt and @bob-carpenter might have more to say about this.

Sure, did that …

Also, during sampling, the GPU is running at 100% according to the Task Manager.

I guess the GPU really only outpaces the CPU when you have lots of variables and lots of data.

When I increase the number of observed data points from 100 (see the example above) to 1 million, the measured sampling times (4 chains with 8000 draws each) are as follows (the only change is the data size, see the sketch after the timings):

# CPU
# 1 million: 4m 6s

# GPU with numpyro
# 1 million: 2m 33s

# GPU with blackjax
# 1 million: 1m 8s

# GPU with nutpie
# 1 million: 1m 20s
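
For completeness, the only change relative to the example above is the size of the simulated data, roughly like this:

# Same generating process as above, just with 1 million observations
observed_data = np.random.normal(loc=true_mean_of_generating_process,
                                 scale=true_sigma_of_generating_process,
                                 size=1_000_000)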