Never mind! Fixed it by reinstalling JAX with:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Sampling time decreased too, which is nice.
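If anyone else runs into this, a quick sanity check after reinstalling is to ask JAX which backend it picked and run a tiny computation on it. This is a minimal sketch that assumes only a standard `jax` install. `jax.default_backend()` should report `gpu` when the CUDA wheel loaded correctly. If it reports `cpu`, JAX most likely fell back to a CPU-only `jaxlib`.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX select? "gpu" means the CUDA build is in use;
# "cpu" usually means it silently fell back to a CPU-only jaxlib.
backend = jax.default_backend()
devices = jax.devices()
print("backend:", backend)
print("devices:", devices)

# Run a tiny computation on the default device to confirm it actually executes.
x = jnp.arange(4.0)
y = jnp.dot(x, x)  # 0*0 + 1*1 + 2*2 + 3*3
print("dot:", float(y))
```

If the backend still shows `cpu`, comparing `devices` before and after the reinstall makes it easy to see whether the GPU is visible at all.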