Hi!
Is anyone else having issues with this sampler when running in a Jupyter notebook?
I spent some time debugging in Spyder, then came back to the notebooks, and when running
import pymc as pm
import pymc.sampling_jax
(which usually gave me an ‘experimental’ warning), I now get a ModuleNotFoundError saying the module is not in the package (full error below):
ModuleNotFoundError Traceback (most recent call last)
/tmp/ipykernel_136021/3747734430.py in <cell line: 1>()
----> 1 import pymc.sampling_jax
~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling_jax.py in <module>
5 # pylint: disable=unused-wildcard-import
6
----> 7 from pymc.sampling.jax import *
~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling/jax.py in <module>
17
18 import arviz as az
---> 19 import jax
20 import numpy as np
21 import pytensor.tensor as at
~/doc/jax.py in <module>
22
23 from arviz.data.base import make_attrs
---> 24 from jax.experimental.maps import SerialLoop, xmap
25 from pytensor.compile import SharedVariable, Supervisor, mode
26 from pytensor.graph.basic import graph_inputs
ModuleNotFoundError: No module named 'jax.experimental'; 'jax' is not a package
I’ve already tried installing jax and jaxlib via both pip and conda, and making sure both are at the same version (0.4.12; with conda-forge I get jax 0.4.13 and jaxlib 0.4.12).
Right now, I have the following versions:
pymc 5.0.1
jax 0.4.12
jaxlib 0.4.12
numpyro 0.12.1
I used the same environment on the same machine to run the same code in Spyder yesterday, and it worked fine.
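One detail in the traceback: the frame that fails is ~/doc/jax.py, not a file under site-packages, and the message says 'jax' is not a package. That might mean `import jax` is resolving to a local single file named jax.py (for example in the notebook's working directory) instead of the installed library. Below is a minimal sketch, using only the standard library, for checking that from the notebook kernel. `where_is` is a hypothetical helper name. It asks the import system where a name would be loaded from without executing the module.

```python
import importlib.util
import os


def where_is(module_name):
    """Return (origin, is_package) for the module that `import module_name`
    would load, without executing it; (None, False) if it cannot be found."""
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        return None, False
    # A real package (a directory with __init__.py) has submodule search
    # locations; a stray single file such as ~/doc/jax.py does not.
    is_package = spec.submodule_search_locations is not None
    return spec.origin, is_package


if __name__ == "__main__":
    origin, is_package = where_is("jax")
    print("kernel working directory:", os.getcwd())
    if origin is None:
        print("jax: not found on sys.path")
    elif is_package:
        print("jax: package at", origin)
    else:
        print("jax: single file at", origin, "(possibly shadowing the installed package)")
```

If this prints a single file rather than a path inside the environment's site-packages, that would explain why the same environment works in Spyder (different working directory) but not in the notebook.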
Thanks in advance!