Is anyone else having issues with this sampler when running it in a Jupyter notebook?
I spent some time debugging in Spyder, then came back to the notebooks, and when running
import pymc as pm
import pymc.sampling_jax
which usually gave me an ‘experimental’ warning, I now get a ModuleNotFoundError saying that 'jax' is not a package (full error below):
ModuleNotFoundError Traceback (most recent call last)
/tmp/ipykernel_136021/3747734430.py in <cell line: 1>()
----> 1 import pymc.sampling_jax
~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling_jax.py in <module>
5 # pylint: disable=unused-wildcard-import
----> 7 from pymc.sampling.jax import *
~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling/jax.py in <module>
18 import arviz as az
---> 19 import jax
20 import numpy as np
21 import pytensor.tensor as at
~/doc/jax.py in <module>
23 from arviz.data.base import make_attrs
---> 24 from jax.experimental.maps import SerialLoop, xmap
25 from pytensor.compile import SharedVariable, Supervisor, mode
26 from pytensor.graph.basic import graph_inputs
ModuleNotFoundError: No module named 'jax.experimental'; 'jax' is not a package
I’ve already tried installing jax and jaxlib via both pip and conda, and made sure both are at the same version (0.4.12; with conda-forge, I get jax 0.4.13 and jaxlib 0.4.12).
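To confirm which jax and jaxlib versions the notebook kernel actually sees, one option is to read the installed package metadata with the standard-library importlib.metadata module (available since Python 3.8). This is a minimal sketch; the helper name installed_versions is hypothetical. Reading metadata does not import anything, so it still works while `import jax` is failing.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(*dists):
    """Return {distribution name: version string, or None if not installed}.

    Hypothetical helper. It only reads installed package metadata and
    imports nothing, so it still works when `import jax` itself fails.
    """
    result = {}
    for dist in dists:
        try:
            result[dist] = version(dist)
        except PackageNotFoundError:
            result[dist] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions("jax", "jaxlib", "pymc").items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running it once in Spyder and once in the notebook shows whether both see the same installs.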
Right now, I have the following versions:
I used the same environment on the same machine to run the same code in Spyder yesterday, and it worked fine.
Thanks in advance!
Here are my imports and the sample call I used with JAX/NumPyro in a Jupyter notebook:
import pymc as pm
idata = pm.sampling.jax.sample_numpyro_nuts(postprocessing_backend='cpu',
Should also work with
You might be accessing a different environment in Spyder and in the notebook. You can print
sys.executable in both to check whether they match.
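A minimal version of that check, meant to be run once in the Spyder console and once in a notebook cell, could look like the sketch below. Besides sys.executable it also prints sys.prefix (the environment root) and the first few sys.path entries, since an unexpected directory near the front of sys.path changes which module a plain `import` picks up. The helper name interpreter_info is hypothetical.

```python
import sys


def interpreter_info(n_path_entries=5):
    """Collect details that identify the running interpreter (hypothetical helper)."""
    return {
        "executable": sys.executable,        # the Python binary running this code
        "version": sys.version.split()[0],   # e.g. "3.10.12"
        "prefix": sys.prefix,                # root of the active environment
        "path_head": sys.path[:n_path_entries],  # first directories searched on import
    }


if __name__ == "__main__":
    for key, value in interpreter_info().items():
        print(f"{key}: {value}")
```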
On another note, your version of PyMC is quite old; I suggest you upgrade. In that case you can reach JAX sampling via
I was unaware of the nuts_sampler="numpyro" syntax, much appreciated!
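If the same notebook has to run both with and without the JAX stack installed, one option is to choose the nuts_sampler value at runtime. The sketch below is not part of PyMC: the helper choose_nuts_sampler and its parameters are hypothetical, and it only checks that the required modules can be located (it neither imports them nor checks versions), falling back to "pymc", PyMC's default sampler, otherwise.

```python
import importlib.util


def choose_nuts_sampler(preferred="numpyro", fallback="pymc",
                        required=("jax", "numpyro")):
    """Return `preferred` if every module in `required` can be located, else `fallback`.

    Hypothetical helper: find_spec only locates top-level modules; it does
    not import them or validate their versions.
    """
    for name in required:
        if importlib.util.find_spec(name) is None:
            return fallback
    return preferred
```

Inside a model context, usage would then be along the lines of pm.sample(nuts_sampler=choose_nuts_sampler()).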
Hi Ricardo! Thank you for your answer.
Yes, I saw there were newer versions, but I wanted to check with the current installations first, just in case.
They were, in fact, the same environment.
I’ve since updated to pymc 5.6.0 and still get the same error when doing import pymc.sampling_jax, and if I run
EDIT: I’ve actually already tried setting up an environment from scratch, having
I’m attaching its environment .yml (saved as .txt) just in case:
pymc_env.txt (10.5 KB)
import jax in that notebook?
Oops, no, I get the same error message:
ModuleNotFoundError: No module named 'jax.experimental'; 'jax' is not a package.
This means I should go straight to JAX support, right?
EDIT: I found another jax.py (~/doc/jax.py, the one in the traceback) on the Python import path (sys.path) that was shadowing the installed jax package. Sorry for the noise!
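For anyone hitting the same error: "'jax' is not a package" means the name jax resolved to a single module file (here ~/doc/jax.py, visible in the traceback) instead of the installed jax package directory. A minimal sketch for spotting this without importing anything, using the standard-library importlib.util.find_spec (the helper name where_does_it_resolve is hypothetical):

```python
import importlib.util


def where_does_it_resolve(name):
    """Report where `import name` would load from, without importing it.

    Returns None if nothing is found, otherwise (origin, is_package).
    An installed package gives is_package=True; a stray single file
    such as ~/doc/jax.py gives is_package=False.
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # Only packages (directories with __init__.py) carry submodule search locations.
    is_package = spec.submodule_search_locations is not None
    return spec.origin, is_package


if __name__ == "__main__":
    print(where_does_it_resolve("jax"))
```

If the printed origin points somewhere outside site-packages, that file is the one being imported.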
I’ll mark it as solved.
Thank you as always Ricardo!
I will try the new method you mentioned of using