I tried to use the NumPyro/JAX sampler in a default Google Colab instance after updating PyMC3 to v3.11.2. The line

import pymc3.sampling_jax

raises the following error:
Traceback (most recent call last)
<ipython-input-51-04ca4ab85d88> in <module>()
----> 1 import pymc3.sampling_jax
      2 with model:
      3     trace = pm.sampling_jax.sample_numpyro_nuts()

/usr/local/lib/python3.7/dist-packages/pymc3/sampling_jax.py in <module>()
     16 import theano.graph.fg
     17
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
     19
     20 import pymc3 as pm

/usr/local/lib/python3.7/dist-packages/theano/link/jax/jax_dispatch.py in <module>()
     84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
     85 try:
---> 86     jax.config.disable_omnistaging()
     87 except AttributeError:
     88     pass

/usr/local/lib/python3.7/dist-packages/jax/config.py in disable_omnistaging(self)
    166   def disable_omnistaging(self):
    167     raise Exception(
--> 168       "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
    169       "see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.")
    170

Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.
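Why the guard in jax_dispatch.py does not help can be reproduced without JAX or Theano. The sketch below uses hypothetical stand-in classes (OldJaxConfig, MidJaxConfig, NewJaxConfig are invented for illustration, not the real jax.config) and copies only the try/except shape from the traceback: the guard catches AttributeError, which covers very old JAX versions where the method is missing, but JAX 0.2.12+ defines the method and raises a plain Exception, which is not an AttributeError and therefore escapes.

```python
# Hypothetical stand-ins mimicking the behaviour described in the traceback;
# these are NOT the real jax.config objects.

class OldJaxConfig:
    """Mimics JAX < 0.2.0, where disable_omnistaging does not exist."""

class MidJaxConfig:
    """Mimics JAX 0.2.0 to 0.2.11, where the call still succeeds."""
    def disable_omnistaging(self):
        pass

class NewJaxConfig:
    """Mimics JAX >= 0.2.12, where the call always raises a plain Exception."""
    def disable_omnistaging(self):
        raise Exception("Disabling of omnistaging is no longer supported")

def theano_style_guard(config) -> str:
    # Same try/except shape as jax_dispatch.py lines 85-88 in the traceback:
    # only AttributeError is caught.
    try:
        config.disable_omnistaging()
    except AttributeError:
        return "skipped: method missing"
    return "omnistaging disabled"

print(theano_style_guard(OldJaxConfig()))  # skipped: method missing
print(theano_style_guard(MidJaxConfig()))  # omnistaging disabled
try:
    theano_style_guard(NewJaxConfig())
except Exception as exc:
    # A bare Exception is not an AttributeError, so it propagates out of
    # the guard, which is what makes the import fail.
    print("escaped the guard:", exc)
```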
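The problem seems to be the JAX version: I am using JAX 0.2.12, and the traceback shows that Theano-PyMC's JAX backend calls jax.config.disable_omnistaging(), which raises an error in JAX 0.2.12 and higher. How can I solve this and use the new JAX sampler?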
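A quick way to check which side of the cutoff a Colab runtime is on is sketched below. The helper name omnistaging_can_be_disabled is hypothetical, and the 0.2.12 threshold is taken only from the error message above; the check does not fix anything and says nothing about other compatibility requirements (for example, the matching jaxlib version).

```python
import re

def omnistaging_can_be_disabled(jax_version: str) -> bool:
    """Hypothetical helper: True if jax_version is below 0.2.12, the release
    from which jax.config.disable_omnistaging() raises (per the error message)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", jax_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {jax_version!r}")
    # Compare (major, minor, patch) numerically rather than as strings,
    # so that "0.2.9" correctly sorts before "0.2.12".
    return tuple(int(part) for part in match.groups()) < (0, 2, 12)

try:
    import jax
    print(f"JAX {jax.__version__}: disable_omnistaging() still allowed?",
          omnistaging_can_be_disabled(jax.__version__))
except ImportError:
    print("JAX is not installed in this environment")
```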
Thanks in advance.