Sampling_jax issues

Hi!
Is anyone else having issues with this sampler when running in a Jupyter notebook?

I spent some time debugging in Spyder, then came back to the notebooks, and when running

import pymc as pm
import pymc.sampling_jax

which usually gives me an ‘experimental’ warning, I now get an error saying the module cannot be found (full traceback below):

ModuleNotFoundError                       Traceback (most recent call last)
/tmp/ipykernel_136021/3747734430.py in <cell line: 1>()
----> 1 import pymc.sampling_jax

~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling_jax.py in <module>
      5 # pylint: disable=unused-wildcard-import
      6 
----> 7 from pymc.sampling.jax import *

~/code/miniforge3/mambaforge/envs/spyder-env/lib/python3.10/site-packages/pymc/sampling/jax.py in <module>
     17 
     18 import arviz as az
---> 19 import jax
     20 import numpy as np
     21 import pytensor.tensor as at

~/doc/jax.py in <module>
     22 
     23 from arviz.data.base import make_attrs
---> 24 from jax.experimental.maps import SerialLoop, xmap
     25 from pytensor.compile import SharedVariable, Supervisor, mode
     26 from pytensor.graph.basic import graph_inputs

ModuleNotFoundError: No module named 'jax.experimental'; 'jax' is not a package

I’ve already tried installing jax and jaxlib via both pip and conda, and made sure both are at the same version (0.4.12; with conda-forge I get jax 0.4.13 and jaxlib 0.4.12).

Right now, I have the following versions:
pymc 5.0.1
jax 0.4.12
jaxlib 0.4.12
numpyro 0.12.1
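A version list like the one above can also be collected from inside the notebook kernel itself. This is a sketch using only the standard library's importlib.metadata; the helper name installed_versions is hypothetical. It reads installed package metadata rather than importing anything, so it reports what the environment contains, not which file "import jax" actually resolves to.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


# Run this in the notebook kernel (and in Spyder) to compare what each one sees.
for name, ver in installed_versions(["pymc", "jax", "jaxlib", "numpyro"]).items():
    print(name, ver)
```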

I used the same environment on the same machine to run the same code in Spyder yesterday, and it worked fine.
Thanks in advance!

Here are my imports and sampling call from when I used JAX or NumPyro in a Jupyter notebook:

import pymc as pm
import pymc.sampling_jax
import numpyro
import blackjax
import jax

with mdl:  # mdl is your existing pm.Model
    idata = pm.sampling.jax.sample_numpyro_nuts(
        postprocessing_backend="cpu",
        idata_kwargs=dict(log_likelihood=False),
    )

This should also work with pm.sampling.jax.sample_blackjax_nuts.

You might be using a different environment in Spyder and in the notebook. You can print sys.executable in each to check whether they match.
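A minimal version of that check: run the snippet once in Spyder’s console and once in a notebook cell, then compare the output. sys.executable is the interpreter running the code and sys.prefix is the root of its environment; the helper name describe_interpreter is only for illustration.

```python
import sys


def describe_interpreter():
    """Return the interpreter path and environment prefix of the current session."""
    return {"executable": sys.executable, "prefix": sys.prefix}


info = describe_interpreter()
# Identical values in Spyder and in the notebook point to the same environment.
print("executable:", info["executable"])
print("prefix:    ", info["prefix"])
```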

On another note, your version of PyMC is quite old; I suggest upgrading. With a recent version you can reach JAX sampling directly via pm.sample(nuts_sampler="numpyro").


Was unaware of the nuts_sampler="numpyro" syntax, much appreciated :+1:


Hi Ricardo! Thank you for your answer.
Yes, I saw there were newer versions, but I wanted to check with my current installation first, just in case.
They were, in fact, the same environment.

I’ve since updated to PyMC 5.6.0 and still get the same error, both with import pymc.sampling_jax and with pm.sample(nuts_sampler="numpyro").

EDIT: I’ve also already tried setting up an environment from scratch, with:
matplotlib 3.7.2
numpy 1.25.1
pandas 2.0.3
arviz 0.15.1
pymc 5.6.0
jax 0.4.13
jaxlib 0.4.12
numpyro 0.12.1
I’m attaching its environment .yml (as a .txt file) just in case.

pymc_env.txt (10.5 KB)

Can you import jax in that notebook?

Oops, no, I get the same error message: ModuleNotFoundError: No module named 'jax.experimental'; 'jax' is not a package.

This means I should go straight to JAX support, right?
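Before going to JAX support, it may be worth checking which file Python would load for jax. A sketch using only the standard library; the helper name locate_module is hypothetical. importlib.util.find_spec locates a top-level module without executing it, so the check works even when "import jax" itself crashes.

```python
import importlib.util


def locate_module(name):
    """Return (origin, is_package) for a top-level module, or None if it cannot be found.

    find_spec only locates the module; it does not run the module's code.
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # Packages have submodule_search_locations; a single .py file does not.
    return spec.origin, spec.submodule_search_locations is not None


print(locate_module("jax"))
```

With a healthy install the origin should end in site-packages/jax/__init__.py and is_package should be True. Here it would presumably have pointed at ~/doc/jax.py with is_package False, which is what the "'jax' is not a package" message complains about.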

EDIT: I found another jax.py on the Python path (the ~/doc/jax.py in the traceback above) that was shadowing the real jax package. Sorry for the noise!
I’ll mark this as solved.
Thank you as always, Ricardo!
I’ll try the new nuts_sampler="numpyro" approach you mentioned.
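The shadowing mechanism is easy to reproduce. The sketch below uses a made-up module name, shadowdemo, instead of jax so it cannot interfere with a real install: it creates a proper package in one temporary directory and a stray single-file module of the same name in another directory that is searched first. The find_module_files helper is a rough, hypothetical check of which sys.path entries provide a given name; it ignores zip imports, compiled extension modules, and namespace packages.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def find_module_files(name):
    """Rough check: files on sys.path that provide `name`, in search order."""
    hits = []
    for entry in sys.path:
        base = Path(entry or ".")  # "" stands for the current working directory
        # Within one directory, a regular package takes precedence over a same-named .py file.
        for candidate in (base / name / "__init__.py", base / f"{name}.py"):
            if candidate.is_file():
                hits.append(candidate)
                break
    return hits


with tempfile.TemporaryDirectory() as real_dir, tempfile.TemporaryDirectory() as stray_dir:
    # The "real" package, with an experimental submodule.
    pkg = Path(real_dir, "shadowdemo")
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (pkg / "experimental.py").write_text("")
    # A stray single-file module with the same name, like ~/doc/jax.py in this thread.
    Path(stray_dir, "shadowdemo.py").write_text("")

    sys.path[:0] = [stray_dir, real_dir]  # the stray directory is searched first
    importlib.invalidate_caches()
    try:
        found = find_module_files("shadowdemo")
        try:
            importlib.import_module("shadowdemo.experimental")
            error = None
        except ModuleNotFoundError as exc:
            error = str(exc)
    finally:
        # Undo the changes so the demo leaves no trace.
        del sys.path[:2]
        sys.modules.pop("shadowdemo", None)

print([path.name for path in found])
print(error)
```

The stray file is listed first, and the captured error has the same shape as the one in the traceback: No module named 'shadowdemo.experimental'; 'shadowdemo' is not a package. The usual fix, as in this thread, is to rename or move the stray file.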
