The solution was easier than expected:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
However, when I check whether the GPU has been found, I get the following output:
import jax
jax.default_backend()
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host CUDA Interpreter
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
'gpu'
jax.devices()
[GpuDevice(id=0, process_index=0)]
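The three INFO lines appear to be informational. JAX probes the ROCm, TPU and plugin backends, reports that they are unavailable, and then uses CUDA. That is why default_backend() still returns 'gpu' and jax.devices() lists a GpuDevice. The minimal sketch below double-checks that a small computation actually runs, and optionally hides the probe messages. It assumes the messages come from the logger named jax._src.xla_bridge. That is an internal module name, so treat it as an assumption that may not hold in other jax versions. The helper names are hypothetical.

```python
import logging

import jax
import jax.numpy as jnp


def quiet_backend_probe_logs(level: int = logging.WARNING) -> None:
    # Assumption: the "Unable to initialize backend ..." INFO lines are emitted
    # by this internal jax logger; raising its level hides them.
    logging.getLogger("jax._src.xla_bridge").setLevel(level)


def check_jax_device() -> dict:
    """Run a tiny matmul and report which backend and devices JAX used."""
    backend = jax.default_backend()              # 'gpu' in the output above
    platforms = [d.platform for d in jax.devices()]
    x = jnp.ones((4, 4))
    y = (x @ x).block_until_ready()              # wait for the computation to finish
    return {"backend": backend, "platforms": platforms, "sum": float(y.sum())}


quiet_backend_probe_logs()  # call before the first jax.devices() or computation
print(check_jax_device())
```

With the CUDA build working, one would expect backend 'gpu' and a platforms list containing 'gpu'. The sum is 64.0 on any backend, because each of the 16 entries of the 4x4 all-ones product is 4.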
Does that mean the installation was not successful? Additionally, I get the following traceback:
File "/data/projects/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/pymc3/sampling_jax.py", line 137, in sample_numpyro_nuts
fns = jax_funcify(fgraph)
File "/data/projects/miniconda3/envs/pm3envJAX/lib/python3.8/functools.py", line 875, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 676, in jax_funcify_FunctionGraph
jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 676, in <listcomp>
jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
[Previous line repeated 11 more times]
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 121, in compose_jax_funcs
jax_return_func = jax_funcify(out_node.op)
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/functools.py", line 875, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 197, in jax_funcify
raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: AllocDiag{offset=0, axis1=0, axis2=1}
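The functools.py wrapper frames in the traceback show that jax_funcify is dispatched on the class of its argument (functools.singledispatch). The error therefore means that no JAX conversion is registered for the AllocDiag Op class. The toy sketch below reproduces that mechanism with hypothetical stand-in classes (Op, AllocDiag, SomeOtherOp) and a hypothetical toy_jax_funcify function. It is not Theano's code, and the real classes and registration hooks may differ.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph operation (not Theano's class)."""


class AllocDiag(Op):
    """Hypothetical stand-in named after the Op in the error message."""

    def __init__(self, offset=0, axis1=0, axis2=1):
        self.offset, self.axis1, self.axis2 = offset, axis1, axis2

    def __repr__(self):
        return f"AllocDiag{{offset={self.offset}, axis1={self.axis1}, axis2={self.axis2}}}"


class SomeOtherOp(Op):
    """Hypothetical Op with no registered conversion."""

    def __repr__(self):
        return "SomeOtherOp"


@singledispatch
def toy_jax_funcify(op):
    # Fallback for every Op class without a registered conversion;
    # mirrors the NotImplementedError at the bottom of the traceback.
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")


@toy_jax_funcify.register(AllocDiag)
def _(op):
    # Sketch only: handles offset=0 with a 1-D input, i.e. builds diag(v).
    if op.offset != 0:
        raise NotImplementedError("only offset=0 is sketched here")

    def alloc_diag(v):
        n = len(v)
        return [[v[i] if i == j else 0 for j in range(n)] for i in range(n)]

    return alloc_diag
```

Here toy_jax_funcify(AllocDiag())([1, 2]) returns [[1, 0], [0, 2]], while toy_jax_funcify(SomeOtherOp()) raises the same kind of error as in the traceback. In the real setup, the analogous fix would be a conversion registered for the library's own AllocDiag class. Whether one exists depends on the installed Theano/PyMC version.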
Version numpyro=1.22.3
Version jax=0.4.8
Version jaxlib=0.4.7
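The jax (0.4.8) and jaxlib (0.4.7) versions listed above differ, and 1.22.3 looks unusual for NumPyro. It may therefore help to read the installed versions straight from package metadata. The minimal sketch below uses only the standard library (importlib.metadata, available from Python 3.8). The package list and the installed_versions helper name are assumptions; adjust them to your environment.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names to report (assumption: adjust to your environment).
PACKAGES = ("jax", "jaxlib", "numpyro", "numpy", "pymc3")


def installed_versions(packages=PACKAGES) -> dict:
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}={ver}")
```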