JAX Docker image was publicly released

NVIDIA has published a JAX Docker image, so running PyMC on a GPU in Docker just got easier.

For anyone who wants to recreate this, here is a basic Dockerfile for running interactively:

FROM nvcr.io/nvidia/jax:23.08-t5x-py3

RUN apt-get update && apt-get install -y --no-install-recommends \
  build-essential \
  git \
  curl \
  vim \
  python3-dev \
  python3-pip

RUN python3 -V

WORKDIR /task-irt

RUN pip install --upgrade pip

COPY requirements.txt ./

RUN pip install -r requirements.txt

COPY . .

And the requirements.txt:

arviz==0.15.1
boto3==1.28.2
botocore==1.31.2
cachetools==5.3.1
cloudpickle==2.2.1
cons==0.4.6
contourpy==1.1.0
cycler==0.11.0
environs==9.5.0
etuples==0.3.9
fastprogress==1.0.3
filelock==3.12.2
fonttools==4.40.0
h5netcdf==1.1.0
h5py==3.9.0
jax[cuda11_pip]==0.4.13
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jmespath==1.0.1
kiwisolver==1.4.4
logical-unification==0.4.6
matplotlib==3.7.1
miniKanren==1.0.3
ml-dtypes==0.2.0
multipledispatch==1.0.0
numpy==1.24.0
numpyro==0.12.1
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvcc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.9.2.26
nvidia-cufft-cu11==10.9.0.58
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.2
Pillow==9.5.0
Pygments==2.4.1
pymc==5.9.0
pyparsing==3.1.0
pytensor==2.17.0
python-dateutil==2.8.2
pytz==2023.3
s3transfer==0.6.1
scipy==1.10.0
six==1.16.0
toolz==0.12.0
tqdm==4.65.0
traitlets==5.1.1
typing_extensions==4.6.3
tzdata==2023.3
urllib3==1.26.16
xarray==2023.1.0
xarray-einstats==0.5.1
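
If you want to double-check that the built image really ended up with these pins, something like the following could be run inside the container. It is only a rough sketch: `parse_pins` and `check_pins` are names made up for this example, it only understands plain `name==version` lines (extras such as `[cuda11_pip]` are dropped, option lines such as `--find-links` are skipped), and it compares version strings literally without any normalization.

```python
import re
from importlib import metadata

# Matches "name==version", optionally with extras such as "jax[cuda11_pip]==0.4.13".
PIN = re.compile(r"^([A-Za-z0-9_.\-]+)(?:\[[^\]]*\])?==(\S+)$")


def parse_pins(lines):
    """Return {package: pinned_version} for every 'name==version' line."""
    pins = {}
    for raw in lines:
        line = raw.strip()
        # Skip blanks, comments, and pip options like --find-links.
        if not line or line.startswith(("#", "-")):
            continue
        match = PIN.match(line)
        if match:
            pins[match.group(1)] = match.group(2)
    return pins


def check_pins(pins):
    """Return {package: (pinned, installed)} for every pin that does not match.

    'installed' is None when the package is not installed at all.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


# A few lines copied from the requirements above, used here as sample input.
sample = [
    "jax[cuda11_pip]==0.4.13",
    "--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
    "pymc==5.9.0",
    "pytensor==2.17.0",
]
for name, (wanted, installed) in check_pins(parse_pins(sample)).items():
    print(f"{name}: pinned {wanted}, installed {installed}")
```

To check the real file, pass an open file handle instead of the sample list, for example `check_pins(parse_pins(open("requirements.txt")))`. Nothing is printed when every pin matches.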

Edited to update the pymc and pytensor versions.


That’s awesome - thanks for sharing!

I notice that the pinned pymc 5.0.2 is very old; it would be nice to update that.

I just updated the pymc and pytensor versions, rebuilt the image, and confirmed that both nvidia-smi and jax detect the device.
All seems well.
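
For anyone wanting to repeat the jax side of that check, a minimal sketch could look like the one below. The helper names `list_jax_devices` and `has_gpu` are invented for this example. It relies on `jax.devices()` and each device's `platform` and `id` attributes, and it assumes a CUDA device reports its platform as `gpu` (the check also accepts `cuda` and `rocm` to be safe). If jax is not importable it returns an empty list instead of failing.

```python
import importlib


def list_jax_devices():
    """Return devices visible to jax as 'platform:id' strings, or [] if jax is missing."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return []
    return [f"{device.platform}:{device.id}" for device in jax.devices()]


def has_gpu(devices):
    """True if any 'platform:id' entry looks like a GPU backend."""
    return any(entry.startswith(("gpu", "cuda", "rocm")) for entry in devices)


devices = list_jax_devices()
print("jax devices:", devices)
print("GPU visible to jax:", has_gpu(devices))
```

If this reports only `cpu` entries inside the container, a likely first thing to check is whether the container was started with GPU access at all, since nvidia-smi would fail in that case too.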
