Sampling_jax issues

Hi Ricardo! Thank you for your answer.
Yes, I saw there were newer versions, but I wanted to check with my current installation first, just in case.
They were, in fact, the same environment.

I’ve since updated to pymc 5.6.0 and still get the same error, both when running import pymc.sampling_jax and when running pm.sample(nuts_sampler="numpyro").

EDIT: I’ve actually already tried setting up an environment from scratch, with:
matplotlib 3.7.2
numpy 1.25.1
pandas 2.0.3
arviz 0.15.1
pymc 5.6.0
jax 0.4.13
jaxlib 0.4.12
numpyro 0.12.1
I’m attaching its .yml file (in .txt format) just in case.

pymc_env.txt (10.5 KB)
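To confirm which versions Python actually picks up inside the activated environment, here is a minimal stdlib-only sketch. It uses importlib.metadata. The package list is simply the one from this post, and the helper name installed_versions is made up for illustration. It only reports installed distributions and does not diagnose the sampling_jax error itself.

```python
from importlib.metadata import PackageNotFoundError, version

# Package list copied from the environment described above (illustrative only).
PACKAGES = ["matplotlib", "numpy", "pandas", "arviz", "pymc", "jax", "jaxlib", "numpyro"]


def installed_versions(packages):
    """Hypothetical helper: map each package name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)  # version string of the installed distribution
        except PackageNotFoundError:
            found[name] = None  # not installed in the environment running this script
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:<12} {ver or 'not installed'}")
```

Running it with the same interpreter that raises the error should show whether the versions match the list above.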