Hi everyone,
This week, I spent some time re-installing my dev environment because I had to move to a new hard drive. I made notes on the steps I took, in the hope that they may be useful for others who want to run PyMC v4 with GPU support for Jax sampling. The steps are as follows:
1. Install Ubuntu 20.04.4 LTS (Focal Fossa)
The latest Ubuntu version is 22.04, but I'm a little conservative, so I decided to install version 20.04. I downloaded the 64-bit PC (AMD64) desktop image from here.
I made a bootable USB drive with Rufus from the Ubuntu desktop .iso image above. You can check the video "How to Make Ubuntu 20.04 Bootable USB Drive". I assume that you have an NVIDIA GPU in your local machine and know how to install Ubuntu from a bootable USB drive. If not, you can search for it on YouTube.
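Once Ubuntu is installed, a quick way to confirm which release you ended up with is to read the standard `/etc/os-release` file. This is a minimal sketch; the helper names `parse_os_release` and `is_ubuntu_focal` are hypothetical and only illustrate the check.

```python
from pathlib import Path


def parse_os_release(text: str) -> dict:
    """Parse the KEY=value lines of an os-release file into a dict."""
    info = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, comments and anything that is not KEY=value.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Values may be wrapped in double or single quotes.
        info[key] = value.strip().strip('"').strip("'")
    return info


def is_ubuntu_focal(path: str = "/etc/os-release") -> bool:
    """Return True if the os-release file at `path` describes Ubuntu 20.04."""
    p = Path(path)
    if not p.exists():
        return False
    info = parse_os_release(p.read_text())
    return info.get("ID") == "ubuntu" and info.get("VERSION_ID") == "20.04"


if __name__ == "__main__":
    print("Ubuntu 20.04:", is_ubuntu_focal())
```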
2. Install NVIDIA Driver, CUDA 11.4, cuDNN v8.2.4
According to Jax's guidelines, we first need to install CUDA and cuDNN to get GPU support for Jax.
To do that, I followed the "Installation of NVIDIA Drivers, CUDA and cuDNN" steps from this guideline (kudos to the author, Ashutosh Kumar).
One note: we may not be able to find a specific NVIDIA driver version in this step. Instead, we can go to https://download.nvidia.com/XFree86/Linux-x86_64/ and download the driver version we need. In my case, I downloaded the file NVIDIA-Linux-x86_64-470.82.01.run from https://download.nvidia.com/XFree86/Linux-x86_64/470.82.01/
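If you need a different driver version, the installer URL can be assembled from the directory layout above. This is a sketch that assumes every version follows the same pattern as 470.82.01 (a `<version>/` directory containing `NVIDIA-Linux-x86_64-<version>.run`); `driver_url` and `download_driver` are hypothetical helper names, and the downloaded `.run` file is then installed following the guideline.

```python
import urllib.request
from typing import Optional

BASE_URL = "https://download.nvidia.com/XFree86/Linux-x86_64"


def driver_url(version: str) -> str:
    """Build the .run installer URL for a driver version.

    Assumes the layout seen for 470.82.01:
    <BASE_URL>/<version>/NVIDIA-Linux-x86_64-<version>.run
    """
    return f"{BASE_URL}/{version}/NVIDIA-Linux-x86_64-{version}.run"


def download_driver(version: str, dest: Optional[str] = None) -> str:
    """Download the installer to `dest` (default: the file name) and return the path."""
    url = driver_url(version)
    filename = dest or url.rsplit("/", 1)[-1]
    # urlretrieve returns (local_filename, headers); we only need the path.
    path, _ = urllib.request.urlretrieve(url, filename)
    return path


if __name__ == "__main__":
    print(driver_url("470.82.01"))
```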
After successfully following the steps in the guideline, we can run `nvidia-smi` and `nvcc --version` to verify the installation. In my case, `nvidia-smi` lists the GPU with driver version 470.82.01, and `nvcc --version` reports CUDA release 11.4.
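The same checks can be bundled into one script that also reads the cuDNN version. This is a sketch under a few assumptions: `nvidia-smi` and `nvcc` are on `PATH`, and cuDNN 8 placed its `cudnn_version.h` header (which defines `CUDNN_MAJOR`, `CUDNN_MINOR` and `CUDNN_PATCHLEVEL`) in one of the hypothetical locations listed in `CUDNN_HEADERS`. Adjust those paths to your install.

```python
import re
import shutil
import subprocess
from pathlib import Path
from typing import Optional

# Assumed header locations; change these to wherever cuDNN was installed.
CUDNN_HEADERS = ["/usr/include/cudnn_version.h", "/usr/local/cuda/include/cudnn_version.h"]


def run(cmd: list) -> Optional[str]:
    """Run a command and return its stdout, or None if it is missing or fails."""
    if shutil.which(cmd[0]) is None:
        return None
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.stdout if result.returncode == 0 else None


def parse_nvcc_release(output: str) -> Optional[str]:
    """Extract e.g. '11.4' from the 'release 11.4, ...' part of `nvcc --version`."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def parse_cudnn_version(header_text: str) -> Optional[str]:
    """Read CUDNN_MAJOR/MINOR/PATCHLEVEL from cudnn_version.h text."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            return None
        parts.append(match.group(1))
    return ".".join(parts)


def report() -> dict:
    """Collect GPU name/driver, CUDA toolkit and cuDNN versions (None if unknown)."""
    smi = run(["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"])
    nvcc = run(["nvcc", "--version"])
    cudnn = None
    for path in CUDNN_HEADERS:
        if Path(path).exists():
            cudnn = parse_cudnn_version(Path(path).read_text())
            break
    return {
        "gpu_and_driver": smi.strip() if smi else None,
        "cuda": parse_nvcc_release(nvcc) if nvcc else None,
        "cudnn": cudnn,
    }


if __name__ == "__main__":
    print(report())
```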
3. Install Jax with GPU support
Following Jax's guidelines, after installing CUDA and cuDNN, we can use pip to install Jax with GPU support.
```
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
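To confirm what pip actually installed, the package versions can be read with `importlib.metadata`. The `+cuda` check below is only a heuristic and an assumption on my part: CUDA builds of `jaxlib` from that release index carried a local version tag such as `+cuda11.cudnn82` at the time, but later Jax releases package GPU support differently. The helper names are hypothetical.

```python
from importlib import metadata
from typing import Dict, Optional


def installed_versions(packages=("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def looks_like_cuda_build(jaxlib_version: Optional[str]) -> bool:
    """Heuristic: CUDA jaxlib wheels of that era had a '+cuda...' local version tag."""
    return jaxlib_version is not None and "+cuda" in jaxlib_version


if __name__ == "__main__":
    v = installed_versions()
    print(v, "CUDA build:", looks_like_cuda_build(v["jaxlib"]))
```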
Check whether the GPU device is available in Jax
We can then start `ipython` or `python` and use the following commands to check:
```
In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]
```
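The same check can be kept as a script, with a small matrix product added so the device actually executes some work rather than only being listed. This is a sketch; `jax_gpu_check` is a hypothetical name, and the function returns `None` when Jax is not installed so it can run anywhere.

```python
import importlib.util


def jax_gpu_check():
    """Return (backend, device platforms, checksum), or None if Jax is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    backend = jax.default_backend()  # 'gpu' when the CUDA build is picked up
    platforms = [d.platform for d in jax.devices()]

    # A small matmul on the default device; block until the result is ready.
    x = jnp.ones((512, 512), dtype=jnp.float32)
    y = (x @ x).block_until_ready()
    checksum = float(y[0, 0])  # each entry is a sum of 512 ones
    return backend, platforms, checksum


if __name__ == "__main__":
    print(jax_gpu_check())
```

If `backend` comes back as `'cpu'`, Jax did not pick up the CUDA build, and the installation steps above are the place to look.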
That's it. We have successfully installed Jax with GPU support. Now we can run Jax-based sampling with `pm.sampling_jax.sample_numpyro_nuts(...)` in PyMC v4 on the GPU.
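As an illustration, here is a toy model sampled this way. It is a sketch, not the author's code: the model, the synthetic data and the sampler settings are made up, it assumes the PyMC v4 module layout (`import pymc.sampling_jax`) with numpyro installed, and `chain_method="vectorized"` is my assumption for running several chains on a single GPU. The function returns `None` if those packages are unavailable.

```python
def sample_on_gpu(seed: int = 0):
    """Fit a toy normal model with numpyro NUTS through PyMC v4's Jax backend.

    Returns the InferenceData, or None if PyMC v4 / numpyro are not installed.
    """
    try:
        import numpy as np
        import pymc as pm
        import pymc.sampling_jax  # exposes pm.sampling_jax (PyMC v4 layout)
        import numpyro  # noqa: F401  (backend used by sample_numpyro_nuts)
    except ImportError:
        return None

    rng = np.random.default_rng(seed)
    data = rng.normal(loc=1.0, scale=2.0, size=200)  # synthetic toy data

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        idata = pm.sampling_jax.sample_numpyro_nuts(
            draws=500,
            tune=500,
            chains=2,
            random_seed=seed,
            # Assumed to suit a single GPU: all chains in one vectorized run.
            chain_method="vectorized",
        )
    return idata


if __name__ == "__main__":
    idata = sample_on_gpu()
    print(idata.posterior if idata is not None else "PyMC v4 / numpyro not installed")
```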
Feel free to ask questions here if you run into any difficulty with the steps above.
Cheers!