Installation guide for pymc3jax

This may be trivial, but I am having trouble installing pymc3jax, the module that lets you use the JAX sampler with PyMC3 syntax.

I am on an Ubuntu machine with miniconda.
I tried running the command

pip install git+

But it gives me this message

Collecting git+

Cloning (to revision pymc3jax) to /tmp/pip-req-build-tabxbbvt

Running command git clone -q /tmp/pip-req-build-tabxbbvt

WARNING: Did not find branch or tag 'pymc3jax', assuming revision or ref.
Running command git checkout -q pymc3jax

error: pathspec 'pymc3jax' did not match any file(s) known to git

WARNING: Discarding git+ Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.

ERROR: Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.

Can someone give me some instructions on how to install pymc3jax?
Thank you in advance!

Is there a reference that told you to do that?

There is no pymc3jax module; the command above tries to install a specific branch of the pymc repo, which seems to no longer exist. If you want to use JAX, I’d recommend installing the pymc development version with

pip install git+

which integrates better with JAX and can sample with both the numpyro and blackjax samplers (both written in JAX). The documentation for the latest version, including the numpyro and blackjax sampling methods, is available at Samplers — PyMC dev documentation


Yeah, those are some very outdated instructions. Where did you find them?

You need to install pymc 4.0.0b3 with aesara to use JAX.


Thanks for the answers!
I found the instructions here:

So would JAX sampling work if I install pymc v4 following the installation guide at the link?

That is, if I run these commands:

git clone
cd pymc
conda env create -f ./conda-envs/environment-dev-py39.yml
conda activate pymc-dev-py39
pip install .

Thanks again for your time. Hopefully this is helpful for other people as well.


I can confirm that the Linux install instructions, which install directly from the repo, pull in the following packages among all the other dependencies (I tested these steps on a Lubuntu Impish 21.10 laptop):

jax                       0.3.2              pyhd8ed1ab_0    conda-forge
jaxlib                    0.3.0            py39h3498573_3    conda-forge
aesara                    2.5.1            py39h788985e_0    conda-forge
arviz                     0.11.4             pyhd8ed1ab_0    conda-forge
python                    3.9.10          h85951f9_2_cpython    conda-forge

As @twiecki and @OriolAbril mentioned, this will permit you to use jax samplers.


It worked. I just had to install numpyro as well.
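Since numpyro had to be installed separately here, a quick stdlib-only check can confirm that the JAX sampling dependencies are present before you try to sample. This is a minimal sketch; the package list is an assumption drawn from this thread (blackjax would only matter if you use the blackjax sampler), and the helper name check_jax_sampling_deps is hypothetical.

```python
from importlib import metadata

# Distributions this thread relies on for JAX sampling (assumed list).
# Add "blackjax" if you plan to use the blackjax sampler.
REQUIRED = ("pymc", "aesara", "jax", "jaxlib", "numpyro")


def check_jax_sampling_deps(names=REQUIRED):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in check_jax_sampling_deps().items():
        print(f"{name:10s} {version or 'NOT INSTALLED'}")
```

Any package reported as NOT INSTALLED (numpyro, in this case) can then be installed into the active pymc-dev-py39 environment.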

As I would like to fit some data with a model that cannot be expressed with the usual pymc3 operations, I wanted to use the JAX sampler to speed up the notebook “Using a ‘black box’ likelihood function”, which you can find at this link:

But I get the error

NotImplementedError: No JAX conversion for the given `Op`: LogLikeWithGrad

I guess the LogLikeWithGrad operator, as written in the reference notebook, has no JAX implementation yet.
Does anyone in the community have suggestions on how to make it work?
Thank you in advance

As that is a custom Op created in that notebook, you would need to write a jax_funcify conversion for it yourself. There’s a guide on how to do that here: Adding JAX and Numba support for Ops — Aesara 2.5.2+4.g980c4c2c7.dirty documentation
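To see what "writing a jax_funcify conversion" involves: Aesara's jax_funcify is a functools.singledispatch function keyed on the Op's class. Supporting a new Op means registering a converter for that class that returns a plain function of the Op's inputs. Ops without a registered converter fall through to a default that raises NotImplementedError, which is where the "No JAX conversion" error comes from. The stdlib-only sketch below mimics that pattern so it runs without Aesara. The LogLikeWithGrad stand-in class, its x/data/sigma attributes, and the straight-line Gaussian likelihood are hypothetical and are not the notebook's actual code. A real converter would import jax_funcify from aesara.link.jax.dispatch and use jax.numpy so that JAX can trace and differentiate the returned function.

```python
from functools import singledispatch
from typing import Sequence


# Stand-in for Aesara's dispatcher. In real code you would instead write
#   from aesara.link.jax.dispatch import jax_funcify
@singledispatch
def jax_funcify(op, **kwargs):
    # Default case: no converter registered for this Op's class.
    raise NotImplementedError(
        f"No JAX conversion for the given `Op`: {type(op).__name__}"
    )


class LogLikeWithGrad:
    """Hypothetical stand-in for the notebook's custom Op.

    It only stores the fixed data the likelihood needs; the real one is an
    Aesara Op subclass.
    """

    def __init__(self, x: Sequence[float], data: Sequence[float], sigma: float):
        self.x = list(x)
        self.data = list(data)
        self.sigma = sigma


@jax_funcify.register(LogLikeWithGrad)
def jax_funcify_LogLikeWithGrad(op, **kwargs):
    # Capture the Op's constants in a closure and return a pure function of
    # the Op's input (theta). With real JAX, use jax.numpy operations here.
    x, data, sigma = op.x, op.data, op.sigma

    def loglike(theta):
        m, c = theta
        # Gaussian log-likelihood (up to a constant) of a straight-line model.
        sq_resid = sum((d - (m * xi + c)) ** 2 for xi, d in zip(x, data))
        return -0.5 * sq_resid / sigma**2

    return loglike


if __name__ == "__main__":
    op = LogLikeWithGrad(x=[0.0, 1.0], data=[1.0, 3.0], sigma=1.0)
    fn = jax_funcify(op)
    print(fn((2.0, 1.0)))  # parameters that fit the data exactly
    print(fn((0.0, 0.0)))
```

The closure design keeps the returned function free of Python-side state other than constants, which is what lets a JAX-based version be traced and compiled.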