This may be very trivial, but I am having trouble installing pymc3jax, the module that lets you use the JAX sampler with pymc3 syntax.
I am on an Ubuntu machine with miniconda.
I tried running the command
pip install git+https://github.com/pymc-devs/pymc3.git@pymc3jax
But it gives me this message
Cloning https://github.com/pymc-devs/pymc3.git (to revision pymc3jax) to /tmp/pip-req-build-tabxbbvt
Running command git clone -q https://github.com/pymc-devs/pymc3.git /tmp/pip-req-build-tabxbbvt
WARNING: Did not find branch or tag 'pymc3jax', assuming revision or ref.
Running command git checkout -q pymc3jax
error: pathspec 'pymc3jax' did not match any file(s) known to git
WARNING: Discarding git+https://github.com/pymc-devs/pymc3.git@pymc3jax. Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.
ERROR: Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.
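For context, the "pathspec 'pymc3jax' did not match" error means that, after cloning, git could not find any branch, tag or commit called pymc3jax, so pip had nothing to check out. A minimal sketch of checking whether a branch or tag exists on the remote before handing the URL to pip, assuming git is on the PATH (the helper name remote_ref_exists is hypothetical, and it does not check bare commit SHAs, which pip also accepts after the @):

```python
import subprocess


def remote_ref_exists(repo_url: str, ref: str) -> bool:
    """Return True if `ref` is a branch or tag name on the remote repository.

    Relies on `git ls-remote --exit-code`, which exits with status 2 when no
    matching refs are found. Commit SHAs are NOT checked by this helper.
    """
    result = subprocess.run(
        # --heads and --tags together limit the listing to branches and tags.
        ["git", "ls-remote", "--exit-code", "--heads", "--tags", repo_url, ref],
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        return True
    if result.returncode == 2:
        return False
    # Any other status (bad URL, network failure, ...) is a real error.
    raise RuntimeError(result.stderr.strip())
```

Usage would be something like remote_ref_exists("https://github.com/pymc-devs/pymc.git", "pymc3jax") before running pip install git+...@pymc3jax; a False result means pip will fail exactly as in the log above.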
Can someone give me some instructions on how to install pymc3jax?
Thank you in advance!
Is there a reference that told you to do that?
There is no pymc3jax module; the command above is trying to install a specific branch of the pymc repo, which seems to no longer exist. If you want to use JAX, I’d recommend installing the pymc development version with
pip install git+https://github.com/pymc-devs/pymc.git
which integrates better with JAX and can sample with both the numpyro and blackjax samplers (both written in JAX). The documentation for the latest version, including the numpyro and blackjax sampling methods, is available on the Samplers page of the PyMC dev documentation.
Yeah those are some very outdated instructions, where did you find them?
You need to install pymc 4.0.0b3 with aesara to use JAX.
Thanks for the answers!
I found the instructions here:
So would JAX sampling work if I install pymc v4 following the installation guide at the link?
That is, if I run the commands:
git clone https://github.com/pymc-devs/pymc/
conda env create -f ./conda-envs/environment-dev-py39.yml
conda activate pymc-dev-py39
pip install .
Thanks again for your time. Hopefully this is helpful for other people as well.
I can confirm that the Linux install instructions that build directly from the repo will install the following (among all the other dependencies; I tested these install steps on a Lubuntu Impish 21.10 laptop):
jax 0.3.2 pyhd8ed1ab_0 conda-forge
jaxlib 0.3.0 py39h3498573_3 conda-forge
aesara 2.5.1 py39h788985e_0 conda-forge
arviz 0.11.4 pyhd8ed1ab_0 conda-forge
python 3.9.10 h85951f9_2_cpython conda-forge
As @twiecki and @OriolAbril mentioned, this will permit you to use
It worked. I just had to install numpyro as well.
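Since numpyro had to be installed separately here, a quick way to see which of the packages discussed in this thread are visible from the activated environment is a small standard-library check. This is a sketch using importlib.metadata (Python 3.8+); the package list is just the names mentioned in this thread:

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions mentioned in this thread; numpyro and blackjax provide the JAX samplers.
PACKAGES = ["pymc", "aesara", "jax", "jaxlib", "arviz", "numpyro", "blackjax"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:10s} {ver or 'NOT INSTALLED'}")
```

Running it inside the pymc-dev-py39 environment should list versions comparable to the conda output above, with NOT INSTALLED flagging anything (such as numpyro) that still needs a separate install.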
As I would like to fit some data with a model that cannot be expressed with the usual pymc3 operations, I wanted to use the JAX sampler to speed up the notebook “Using a ‘black box’ likelihood function”, which you can find at this link:
But I get the error
NotImplementedError: No JAX conversion for the given `Op`: LogLikeWithGrad
I guess that the LogLikeWithGrad Op, as written in the reference notebook, does not have a JAX implementation yet.
Do you, or any member of the community, have by any chance any suggestions on how to make it work?
Thank you in advance
As that is a custom Op created in that notebook, you would need to register a jax_funcify conversion for it yourself. There’s a guide on how to do that in the Aesara documentation, under “Adding JAX and Numba support for Ops”.