This week I spent some time reinstalling my dev environment because I had to move to a new hard drive. I made a note of the steps I took, hoping it may be useful for others who want to run PyMC v4 with GPU support for JAX sampling. The steps are as follows:
The latest Ubuntu version is 22.04, but I'm a little conservative, so I decided to install version 20.04. I downloaded the 64-bit PC (AMD64) desktop image from here.
I made a bootable USB drive with Rufus using the above Ubuntu desktop .iso image; the video "How to Make Ubuntu 20.04 Bootable USB Drive" walks through this. I assume that you have an NVIDIA GPU in your local machine and that you know how to install Ubuntu from a bootable USB drive. If not, you can find tutorials on YouTube.
According to JAX's installation guidelines, to get GPU support for JAX we first need to install CUDA and cuDNN.
One note: at this step we may not be able to find the specific version of the NVIDIA driver we need. Instead, we can go to https://download.nvidia.com/XFree86/Linux-x86_64/ and download that exact driver version. In my case, I downloaded the file NVIDIA-Linux-x86_64-470.82.01.run from https://download.nvidia.com/XFree86/Linux-x86_64/470.82.01/
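If you need to look up the installer for a different driver version, the archive layout seen above for 470.82.01 suggests a simple URL pattern. Below is a minimal sketch, assuming every version folder follows that same `<version>/NVIDIA-Linux-x86_64-<version>.run` layout; it also asks `nvidia-smi` which driver is currently installed, if any. The function names are hypothetical helpers, not part of any NVIDIA tool.

```python
import shutil
import subprocess

# Base directory of NVIDIA's Linux x86_64 driver archive (from the link above).
BASE_URL = "https://download.nvidia.com/XFree86/Linux-x86_64"


def driver_download_url(version: str) -> str:
    """Build the .run installer URL for a given driver version.

    Assumption: the archive keeps the layout observed for 470.82.01:
    <BASE_URL>/<version>/NVIDIA-Linux-x86_64-<version>.run
    """
    return f"{BASE_URL}/{version}/NVIDIA-Linux-x86_64-{version}.run"


def installed_driver_version():
    """Return the driver version reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None  # no NVIDIA driver tools on PATH
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    # nvidia-smi prints one line per GPU; take the first one.
    lines = result.stdout.strip().splitlines()
    return lines[0].strip() if lines else None


if __name__ == "__main__":
    print(driver_download_url("470.82.01"))
    print("installed driver:", installed_driver_version())
```

For 470.82.01 this reproduces the download link above, so it is mainly a convenience when trying a different version.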
After successfully following the steps in the guidelines, we can run nvcc --version to verify the installation; it prints the release of the installed CUDA compiler.
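The same check can be scripted, for example to confirm that the CUDA release matches what the JAX wheel expects. This is a minimal sketch, assuming `nvcc` is on your PATH (the CUDA post-installation steps add a directory such as /usr/local/cuda/bin to PATH) and that its output contains a line of the form "Cuda compilation tools, release X.Y, ...". The helper names are hypothetical.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(output: str):
    """Extract the CUDA release (e.g. '11.4') from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def installed_cuda_release():
    """Run `nvcc --version` and return the CUDA release, or None if nvcc is missing."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None  # CUDA not installed, or its bin directory is not on PATH
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    print("CUDA release:", installed_cuda_release())
```

If this prints None although CUDA is installed, the PATH export from the CUDA post-installation steps is the first thing to check.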
Following JAX's guidelines, once CUDA and cuDNN are installed, we can use pip to install JAX with GPU support:
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
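To confirm which wheels pip actually installed, a small sketch using only the standard library's `importlib.metadata` (available from Python 3.8, the version Ubuntu 20.04 ships) can report the installed jax and jaxlib versions. The function name `installed_versions` is just for illustration. The GPU code lives in jaxlib, and as far as I know its CUDA wheels from this period carried a local version suffix (something like +cuda11.cudnn82), which is a quick hint that the GPU build was picked up.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # package not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```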
We can then start python and use the following commands to check:
In : import jax
In : jax.default_backend()
Out: 'gpu'
In : jax.devices()
Out: [GpuDevice(id=0, process_index=0)]
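The interactive check above can also be kept as a small script. This is a minimal sketch: besides reporting the backend and devices, it runs a tiny matrix product so that a broken CUDA/cuDNN setup shows up as an error rather than only as a device name. The function name and the 100 x 100 size are arbitrary choices for illustration; with a working GPU install the backend should be 'gpu', as in the session above.

```python
import importlib.util


def jax_gpu_report():
    """Describe the local JAX installation: backend, devices, and a smoke test."""
    if importlib.util.find_spec("jax") is None:
        return {"installed": False, "backend": None, "devices": [], "check": None}

    import jax
    import jax.numpy as jnp

    backend = jax.default_backend()  # 'gpu' when the CUDA build is active
    devices = [str(d) for d in jax.devices()]

    # Tiny computation on the default device: each entry of x @ x is 100,
    # so the total is 100 * 100 * 100 = 1e6 (exact in float32).
    x = jnp.ones((100, 100))
    total = float((x @ x).sum())

    return {"installed": True, "backend": backend, "devices": devices, "check": total}


if __name__ == "__main__":
    print(jax_gpu_report())
```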
That's it: JAX is now installed with GPU support, and we can run JAX-based sampling with pm.sampling_jax.sample_numpyro_nuts(...) in PyMC v4 on the GPU.
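As a quick end-to-end test, here is a minimal sketch that fits a toy Normal model with `sample_numpyro_nuts`. The model, the synthetic data, and the function name `fit_normal_model` are made up for illustration. As far as I know, in PyMC v4 the `sampling_jax` module is not loaded by `import pymc` and has to be imported explicitly, which the sketch does. The imports sit inside the function only so the file can be loaded on a machine without PyMC.

```python
def fit_normal_model(seed: int = 0):
    """Fit a toy Normal model with PyMC v4's NumPyro-based NUTS sampler."""
    # Imported here so this file can be loaded even where PyMC is not installed.
    import numpy as np
    import pymc as pm
    import pymc.sampling_jax  # in PyMC v4 this module must be imported explicitly

    rng = np.random.default_rng(seed)
    data = rng.normal(loc=1.0, scale=2.0, size=500)  # synthetic observations

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        # NUTS runs through NumPyro/JAX, so it uses the GPU when JAX sees one.
        idata = pm.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=2)
    return idata


if __name__ == "__main__":
    idata = fit_normal_model()
    print(idata.posterior["mu"].mean().item())
```

The call returns an ArviZ InferenceData object, so `arviz.summary(idata)` can be used to inspect the posterior; the estimates for mu and sigma should land near the 1.0 and 2.0 used to generate the data.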
Feel free to ask questions here if you run into any difficulty with the steps above.