Set up environment for JAX sampling with GPU support in PyMC v4

Hi, I tried to follow the above instructions and ran into many issues; however, as a general framework the instructions were very helpful. Here is how I finally got things working:

Graphics driver:
I installed the driver suggested by Ubuntu itself (via “Software and Updates”).
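A quick way to confirm the driver is active is to ask nvidia-smi which GPUs it sees. This is a minimal sketch, assuming nvidia-smi is on the PATH; the helper names parse_smi_csv and query_gpus are hypothetical, only the nvidia-smi query flags are real.

```python
import shutil
import subprocess


def parse_smi_csv(text):
    """Parse `nvidia-smi --query-gpu=name,driver_version --format=csv,noheader` output."""
    gpus = []
    for line in text.strip().splitlines():
        # Split on the last comma so the GPU name stays intact.
        name, driver = (field.strip() for field in line.rsplit(",", 1))
        gpus.append({"name": name, "driver_version": driver})
    return gpus


def query_gpus():
    """Return the GPUs reported by the NVIDIA driver, or None if nvidia-smi is missing."""
    if shutil.which("nvidia-smi") is None:
        return None
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_smi_csv(out)
```

Calling query_gpus() should return one entry per card; None or an exception means the driver is not set up correctly.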

CUDA:
I used the instructions in

to install cuda-11.4
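To check which toolkit release is actually installed, nvcc --version prints a line such as "Cuda compilation tools, release 11.4, ...". A minimal sketch, assuming the toolkit lives at /usr/local/cuda-11.4 (adjust the path if yours differs); the function names are hypothetical.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Extract the 'major.minor' release from `nvcc --version` output, or None."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def installed_cuda_release(nvcc="/usr/local/cuda-11.4/bin/nvcc"):
    """Run nvcc (path is an assumption) and return its release, or None if not found."""
    if shutil.which(nvcc) is None:
        return None
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=True).stdout
    return parse_nvcc_release(out)
```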

cuDNN:
I downloaded the tgz archive from the cuDNN website and followed the instructions in

However, when running the commands under “Copy the following files into the CUDA toolkit directory.”, I changed every occurrence of cuda to cuda-11.4.
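After copying, it is worth confirming the cuDNN headers and libraries really ended up in the versioned toolkit directory. A minimal sketch, assuming the layout used by the cuDNN tar-archive instructions (headers in include/, libraries in lib64/) under /usr/local/cuda-11.4; cudnn_files is a hypothetical helper.

```python
from pathlib import Path


def cudnn_files(cuda_root="/usr/local/cuda-11.4"):
    """Return (headers, libraries) for cuDNN found under a CUDA toolkit directory.

    Assumes headers in <root>/include and shared libraries in <root>/lib64.
    Missing directories simply produce empty lists.
    """
    root = Path(cuda_root)
    headers = sorted(p.name for p in (root / "include").glob("cudnn*.h"))
    libs = sorted(p.name for p in (root / "lib64").glob("libcudnn*"))
    return headers, libs
```

Two empty lists suggest the files were copied to a different directory (for example an unversioned cuda path).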

JAX and JAXLIB:
I followed the instructions in

to install compatible versions of these libraries
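One way to tell a CUDA-enabled jaxlib from a CPU-only one is the installed distribution version. This sketch assumes, as a heuristic, that the CUDA jaxlib wheels of the 0.3.x era carry a local version tag such as +cuda11.cudnn82; the helper names are hypothetical, and importlib.metadata is the standard-library API.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed distribution version of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def split_local_tag(version):
    """Split '0.3.14+cuda11.cudnn82' into ('0.3.14', 'cuda11.cudnn82')."""
    public, _, local = version.partition("+")
    return public, local or None


def looks_like_cuda_build(jaxlib_version):
    """Heuristic (assumption): CUDA jaxlib wheels carry a '+cuda...' local tag."""
    _, local = split_local_tag(jaxlib_version)
    return local is not None and local.startswith("cuda")
```

For example, looks_like_cuda_build(installed_version("jaxlib")) returning False would hint that a CPU-only wheel was installed.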

Anaconda:
I installed it from:
https://docs.anaconda.com/anaconda/install/linux/

PyMC v4:
I followed:
https://www.pymc.io/projects/docs/en/latest/installation.html

to install PyMC v4, but I stopped short of running the last command (i.e. I didn’t run “pip install blackjax”) because I thought it would probably conflict with my jax/jaxlib installations.
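If you want to find out whether installing blackjax actually touches jax/jaxlib, one option is to record the installed versions before and after the pip command and compare them. A minimal sketch; the package list and the helper names snapshot and diff are illustrative assumptions.

```python
from importlib import metadata

# Packages whose versions we want to watch (illustrative choice).
WATCHED = ("jax", "jaxlib", "numpyro", "blackjax", "pymc")


def snapshot(packages=WATCHED):
    """Map each package name to its installed version (None if not installed)."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def diff(before, after):
    """Return {package: (old, new)} for every package whose version changed."""
    return {
        name: (before.get(name), after.get(name))
        for name in sorted(set(before) | set(after))
        if before.get(name) != after.get(name)
    }
```

Take before = snapshot(), run pip install blackjax, then inspect diff(before, snapshot()); a change to jax or jaxlib would confirm the concern.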

After all of this, when I open a Jupyter notebook in my pymc_env virtual environment and run:

import jax
import jaxlib
print(jax.__version__)
print(jaxlib.__version__)

I get:
0.3.14
0.3.14

And also from:
jax.default_backend()

I get:
'gpu'

So it seems JAX has access to the GPU, but when I sample with pm.sampling_jax.sample_numpyro_nuts it is quite slow, and I don’t think it is running on the GPU.
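A rough way to check whether JAX computations really run faster on the GPU is to list the visible devices and time the same jitted operation on CPU and GPU. This is a minimal sketch; the function names and matrix size are arbitrary assumptions, and a fast GPU matmul only shows that JAX can use the GPU, not that sample_numpyro_nuts is using it.

```python
import time

import jax
import jax.numpy as jnp


def visible_devices():
    """List (platform, device_kind) for every device on JAX's default backend."""
    return [(d.platform, d.device_kind) for d in jax.devices()]


def time_matmul(device, n=1000, repeats=5):
    """Average seconds for a jitted n x n matmul whose input is placed on `device`."""
    x = jax.device_put(jnp.ones((n, n), dtype=jnp.float32), device)
    matmul = jax.jit(lambda a: a @ a)
    matmul(x).block_until_ready()  # first call compiles; keep it out of the timing
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for each result before timing.
        matmul(x).block_until_ready()
    return (time.perf_counter() - start) / repeats
```

Compare time_matmul(jax.devices("cpu")[0]) with time_matmul(jax.devices("gpu")[0]); note that jax.devices("gpu") raises an error if no GPU backend is available.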

Any advice would be appreciated.
