This may be trivial, but I am having trouble installing pymc3jax, the module that lets you use the JAX sampler with pymc3 syntax.
I am on an Ubuntu machine with miniconda.
I tried running the command
pip install git+https://github.com/pymc-devs/pymc3.git@pymc3jax
but it gives me this message:
Collecting git+https://github.com/pymc-devs/pymc3.git@pymc3jax
Cloning https://github.com/pymc-devs/pymc3.git (to revision pymc3jax) to /tmp/pip-req-build-tabxbbvt
Running command git clone -q https://github.com/pymc-devs/pymc3.git /tmp/pip-req-build-tabxbbvt
WARNING: Did not find branch or tag 'pymc3jax', assuming revision or ref.
Running command git checkout -q pymc3jax
error: pathspec 'pymc3jax' did not match any file(s) known to git
WARNING: Discarding git+https://github.com/pymc-devs/pymc3.git@pymc3jax. Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.
ERROR: Command errored out with exit status 1: git checkout -q pymc3jax Check the logs for full command output.
Can someone give me some instructions on how to install pymc3jax?
Thank you in advance!
Is there any reference that pointed you to do that?
There is no pymc3jax module. The command above tries to install a specific branch of the pymc repo, and that branch no longer seems to exist. If you want to use JAX, I’d recommend installing the pymc development version with
pip install git+https://github.com/pymc-devs/pymc.git
which integrates better with JAX and can sample with both the numpyro and blackjax samplers (both written in JAX). The documentation for the latest version, including the numpyro and blackjax sampling methods, is available on the "Samplers" page of the PyMC dev documentation.
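For reference, here is a minimal sketch of what sampling with the numpyro backend might look like once the development version is installed. It assumes the `pymc.sampling_jax.sample_numpyro_nuts` entry point as shipped in pymc 4.0.0b3 (the import path may differ in other versions), and that jax and numpyro are installed. The helper name `sample_with_numpyro` and the toy model are made up for illustration.

```python
def sample_with_numpyro(observed, draws=500, tune=500, chains=2):
    """Fit a toy Normal-mean model with the NumPyro NUTS sampler."""
    # Imported lazily so that defining this helper does not require JAX.
    import pymc as pm
    import pymc.sampling_jax  # needs jax and numpyro to be installed

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        pm.Normal("y", mu=mu, sigma=1.0, observed=observed)
        # Assumed API for pymc 4.0.0b3; picks up the model from the context.
        return pm.sampling_jax.sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains
        )
```

Calling `sample_with_numpyro([0.1, -0.3, 0.5])` should return the posterior draws for `mu`, provided the JAX dependencies are present.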
Yeah those are some very outdated instructions, where did you find them?
You need to install pymc 4.0.0b3 with aesara to use JAX.
Thanks for the answers!
I found the instruction here:
https://discourse.pymc.io/t/unable-to-get-pymc3-jax-to-run-importerror/6286
So would JAX sampling work if I install pymc v4 by following the installation guide at this link?
https://github.com/pymc-devs/pymc/wiki/Installation-Guide-(Linux)
That is, if I run the commands
git clone https://github.com/pymc-devs/pymc/
cd pymc
conda env create -f ./conda-envs/environment-dev-py39.yml
conda activate pymc-dev-py39
pip install .
Thanks again for your time. Hopefully this is helpful for other people as well.
@Davide_Dal_Bosco
I can confirm that the specific Linux install instructions sourcing directly from the repo will install the following (among all other dependencies - I tested these install steps on a Lubuntu Impish 21.10 laptop):
pymc-4.0.0b3
jax 0.3.2 pyhd8ed1ab_0 conda-forge
jaxlib 0.3.0 py39h3498573_3 conda-forge
aesara 2.5.1 py39h788985e_0 conda-forge
arviz 0.11.4 pyhd8ed1ab_0 conda-forge
python 3.9.10 h85951f9_2_cpython conda-forge
As @twiecki and @OriolAbril mentioned, this will permit you to use the JAX samplers.
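If you want to compare your own environment against the list above, a small sketch like the following (standard library only, Python 3.8+) prints the installed versions. The function name `installed_versions` and the default package list are only illustrative.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("pymc", "aesara", "jax", "jaxlib", "numpyro", "arviz")):
    """Return a dict mapping each package name to its version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the active environment.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:10s} {ver or 'not installed'}")
```

Run it inside the activated conda environment (here `pymc-dev-py39`) so that it inspects the right set of packages.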
It worked. I just had to install numpyro as well.
As I would like to fit some data with a model that cannot be expressed with the usual pymc3 operations, I wanted to use the JAX sampler to speed up the notebook "Using a 'black box' likelihood function", which you can find at this link:
https://docs.pymc.io/en/v3/pymc-examples/examples/case_studies/blackbox_external_likelihood.html
But I get the error
NotImplementedError: No JAX conversion for the given `Op`: LogLikeWithGrad
I guess that the LogLikeWithGrad operator, as written in the reference notebook, does not have a JAX implementation yet.
Do you, or any member of the community, have by any chance any suggestions on how to make it work?
Thank you in advance
As that is a custom Op created in that notebook, you would need to write a jax_funcify implementation for it yourself. There’s a guide on how to do that on the "Adding JAX and Numba support for Ops" page of the Aesara documentation (2.5.2+4.g980c4c2c7.dirty).
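A rough sketch of what such a registration might look like, following the pattern in that guide. It assumes that the Op stores its fixed inputs as `x`, `data` and `sigma` attributes (as the notebook's LogLikeWithGrad appears to do) and that the likelihood is rewritten so JAX can trace it, since a NumPy or Cython function will not work there. The names `register_jax_loglike` and `gaussian_loglike` are hypothetical.

```python
def gaussian_loglike(theta, x, data, sigma):
    """Straight-line Gaussian log-likelihood using only array operators,
    so it can be traced by JAX when theta is a JAX array."""
    m, c = theta[0], theta[1]
    residual = (data - (m * x + c)) / sigma
    return -0.5 * (residual ** 2).sum()


def register_jax_loglike(op_class, loglike=gaussian_loglike):
    """Register a JAX conversion for a custom Aesara Op class."""
    # Imported lazily; requires aesara with its JAX linker available.
    from aesara.link.jax.dispatch import jax_funcify

    @jax_funcify.register(op_class)
    def funcify_loglike(op, **kwargs):
        # Assumed attribute names, taken from the notebook's Op.
        x, data, sigma = op.x, op.data, op.sigma

        def jax_loglike(theta):
            return loglike(theta, x, data, sigma)

        return jax_loglike

    return funcify_loglike
```

With this in place, `register_jax_loglike(LogLikeWithGrad)` would be called once before sampling. As far as I understand, the JAX samplers obtain gradients by differentiating the converted function themselves, so the Op's hand-written gradient should not be needed on this path, but I have not verified that against the notebook.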