PyMC3 on GPU using JAX

Hi everyone,

I am trying to create a conda environment with PyMC3 and JAX, following this link. However, the installation fails with the following error:

```
Collecting git+https://github.com/pymc-devs/pymc3.git@pymc3jax
Cloning https://github.com/pymc-devs/pymc3.git (to revision pymc3jax) to /tmp/pip-req-build-s6d8qk5m
Running command git clone --quiet https://github.com/pymc-devs/pymc3.git /tmp/pip-req-build-s6d8qk5m
WARNING: Did not find branch or tag 'pymc3jax', assuming revision or ref.
Running command git checkout -q pymc3jax
error: pathspec 'pymc3jax' did not match any file(s) known to git.
```

According to this, that makes sense. Could someone please point me to alternative instructions for installing PyMC3 (or another PyMC version) with JAX, so that I can run my code on the GPU?

I also found another link that says:

> This NB requires the master of [Theano-PyMC](https://github.com/pymc-devs/Theano-PyMC), the [pymc3jax branch of PyMC3](https://github.com/pymc-devs/pymc3/tree/pymc3jax), as well as JAX, TFP-nightly and numpyro.

Again, the `pymc3jax` branch does not exist.
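Before pointing pip at a branch, it can help to confirm that the ref still exists on the remote. The sketch below assumes `git` is on `PATH` and wraps `git ls-remote --heads`; the helper names `parse_ls_remote_heads` and `remote_branch_exists` are hypothetical and not part of pip or PyMC.

```python
from __future__ import annotations

import subprocess

# Repository URL taken from the pip log above.
PYMC_REPO = "https://github.com/pymc-devs/pymc3.git"


def parse_ls_remote_heads(output: str) -> set[str]:
    """Extract branch names from `git ls-remote --heads` output.

    Each line is a commit SHA, a tab, then refs/heads/NAME.
    """
    prefix = "refs/heads/"
    branches = set()
    for line in output.splitlines():
        parts = line.split("\t")
        if len(parts) == 2 and parts[1].startswith(prefix):
            branches.add(parts[1][len(prefix):])
    return branches


def remote_branch_exists(repo_url: str, branch: str) -> bool:
    """Ask the remote (network access required) whether `branch` is a head."""
    result = subprocess.run(
        ["git", "ls-remote", "--heads", repo_url, branch],
        capture_output=True,
        text=True,
        check=True,
    )
    # ls-remote filters by pattern; compare exact names to avoid suffix matches.
    return branch in parse_ls_remote_heads(result.stdout)
```

For example, `remote_branch_exists(PYMC_REPO, "pymc3jax")` should return `False` given the pip error above, which tells you to look for another install route instead of retrying the same URL.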

Thanks in advance! :slight_smile:

The solution was easier than expected: :slight_smile:

```
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
```

However, when checking whether the GPU has been found, I get the following output:

```
import jax
jax.default_backend()
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host CUDA Interpreter
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO - 04/14/23 18:17:47 - 0:00:09 - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
'gpu'
jax.devices()
[GpuDevice(id=0, process_index=0)]
```
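The three INFO lines are not errors: JAX probes the ROCm, TPU and plugin backends, logs that they are unavailable, and then uses CUDA, which is why `jax.default_backend()` returns `'gpu'` and `jax.devices()` lists a `GpuDevice`. They most likely show up only because the application configured logging at INFO level (note the custom timestamp format). A minimal sketch for hiding them and checking the GPU programmatically is below; the logger name `jax._src.xla_bridge` is an assumption about JAX 0.4.x internals, and both helper names are hypothetical.

```python
import logging

# Assumption: in JAX 0.4.x the "Unable to initialize backend ..." messages are
# emitted at INFO level by the logger of the internal module jax._src.xla_bridge.
# This is an implementation detail and may differ in other JAX releases.
JAX_BACKEND_LOGGER = "jax._src.xla_bridge"


def quiet_jax_backend_probe(level: int = logging.WARNING) -> logging.Logger:
    """Raise the threshold of JAX's backend-probe logger only.

    Other loggers, including the application's own INFO logging, are untouched.
    """
    logger = logging.getLogger(JAX_BACKEND_LOGGER)
    logger.setLevel(level)
    return logger


def gpu_is_active(default_backend: str, device_platforms: list) -> bool:
    """True when JAX defaults to the GPU backend and lists at least one GPU device."""
    return default_backend == "gpu" and "gpu" in device_platforms
```

In a session, call `quiet_jax_backend_probe()` before the first JAX call, then `gpu_is_active(jax.default_backend(), [d.platform for d in jax.devices()])`; with the output shown above this should be `True`.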

Does that mean the installation was not successful? I also get the following error:

```
  File "/data/projects/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/pymc3/sampling_jax.py", line 137, in sample_numpyro_nuts
    fns = jax_funcify(fgraph)
  File "/data/projects/miniconda3/envs/pm3envJAX/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 676, in jax_funcify_FunctionGraph
    jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 676, in <listcomp>
    jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
    input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
    input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 155, in compose_jax_funcs
    input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
  [Previous line repeated 11 more times]
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 121, in compose_jax_funcs
    jax_return_func = jax_funcify(out_node.op)
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/data/miniconda3/envs/pm3envJAX/lib/python3.8/site-packages/theano/link/jax/jax_dispatch.py", line 197, in jax_funcify
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: AllocDiag{offset=0, axis1=0, axis2=1}
```

Versions:

- numpyro: 1.22.3
- jax: 0.4.8
- jaxlib: 0.4.7

Using PyMC 5.3.0 solves the problem. :slight_smile:
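The `NotImplementedError` comes from the Theano-PyMC JAX backend used by PyMC3, which had no JAX conversion for `AllocDiag`. PyMC 5 is built on PyTensor instead, and in this thread upgrading to 5.3.0 was enough. The sketch below checks the installed version and runs NUTS through NumPyro with `pymc.sampling.jax.sample_numpyro_nuts`, the PyMC 5 counterpart of the `sample_numpyro_nuts` in the traceback. The toy Normal model, the helper names and the `(5, 3)` threshold (simply the version reported to work here) are assumptions for illustration.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version


def parse_major_minor(ver: str) -> tuple[int, int]:
    """Return (major, minor) from a version string such as '5.3.0' or '5.10.1rc1'."""
    match = re.match(r"(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"Unrecognised version string: {ver!r}")
    return int(match.group(1)), int(match.group(2))


def pymc_is_recent_enough(minimum: tuple[int, int] = (5, 3)) -> bool:
    """True if the `pymc` (not `pymc3`) distribution is installed at >= minimum."""
    try:
        installed = version("pymc")
    except PackageNotFoundError:
        return False
    return parse_major_minor(installed) >= minimum


def sample_toy_model_with_numpyro(data, draws: int = 1000, tune: int = 1000, chains: int = 2):
    """Fit a hypothetical Normal model with NumPyro's NUTS (on the GPU if JAX defaults to it)."""
    # Imported lazily so the version helpers above work without PyMC/JAX installed.
    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_numpyro_nuts

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=1.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=np.asarray(data))
        # "vectorized" runs all chains on one device instead of one chain per device.
        return sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, chain_method="vectorized"
        )
```

Usage: if `pymc_is_recent_enough()` is `True`, `sample_toy_model_with_numpyro(np.random.default_rng(0).normal(1.0, 2.0, size=200))` should return an ArviZ `InferenceData` object.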