After updating PyMC3 to v3.11.2, I tried to use the NumPyro/JAX sampler in a default Google Colab instance. Importing pymc3.sampling_jax
results in the following error:
Traceback (most recent call last)
<ipython-input-51-04ca4ab85d88> in <module>()
----> 1 import pymc3.sampling_jax
2 with model:
3 trace = pm.sampling_jax.sample_numpyro_nuts()
/usr/local/lib/python3.7/dist-packages/pymc3/sampling_jax.py in <module>()
16 import theano.graph.fg
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
20 import pymc3 as pm
/usr/local/lib/python3.7/dist-packages/theano/link/jax/jax_dispatch.py in <module>()
84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
85 try:
---> 86 jax.config.disable_omnistaging()
87 except AttributeError:
88 pass
/usr/local/lib/python3.7/dist-packages/jax/config.py in disable_omnistaging(self)
166 def disable_omnistaging(self):
167 raise Exception(
--> 168 "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
169 "see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.")
Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.
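As far as I can tell from the traceback, the guard in theano/link/jax/jax_dispatch.py only tolerates an AttributeError (the case where the method does not exist, as in JAX < 0.2.0). In JAX 0.2.12 the method exists but raises a plain Exception, which the guard does not catch. The sketch below reproduces that pattern in plain Python. The classes OldJaxConfig, MidJaxConfig and NewJaxConfig are hypothetical stand-ins for jax.config in different releases, not the real JAX objects.

```python
class OldJaxConfig:
    """Hypothetical stand-in for JAX < 0.2.0: no disable_omnistaging at all."""


class MidJaxConfig:
    """Hypothetical stand-in for JAX 0.2.0 to 0.2.11: omnistaging can be disabled."""

    def __init__(self):
        self.omnistaging_enabled = True

    def disable_omnistaging(self):
        self.omnistaging_enabled = False


class NewJaxConfig:
    """Hypothetical stand-in for JAX >= 0.2.12: the method exists but always raises."""

    def disable_omnistaging(self):
        raise Exception(
            "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher"
        )


def guarded_disable(config):
    """Same try/except pattern as in the jax_dispatch.py frame of the traceback."""
    try:
        config.disable_omnistaging()
    except AttributeError:
        # Only the "method is missing" case is swallowed; any other exception propagates.
        pass


if __name__ == "__main__":
    guarded_disable(OldJaxConfig())  # silently passes: AttributeError is caught
    guarded_disable(MidJaxConfig())  # works: omnistaging gets disabled
    try:
        guarded_disable(NewJaxConfig())
    except Exception as exc:
        print("Import would fail with:", exc)
```

So the import can only succeed with a JAX release old enough to still allow disabling omnistaging, or with a Theano version that no longer makes this call.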
It seems to be an incompatibility with the JAX version (I am using JAX 0.2.12). How can I fix this and use the new JAX sampler?
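To check whether an environment is affected before importing pymc3.sampling_jax, one could compare the installed version against 0.2.12, the threshold named in the error message. This is a minimal sketch; the helper names parse_version and omnistaging_removed are hypothetical, and the parser only looks at the leading numeric components of the version string.

```python
def parse_version(version: str) -> tuple:
    """Turn e.g. "0.2.12" into (0, 2, 12), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def omnistaging_removed(jax_version: str) -> bool:
    """True if jax.config.disable_omnistaging() raises in this JAX release."""
    return parse_version(jax_version) >= (0, 2, 12)


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not installed")
    else:
        if omnistaging_removed(jax.__version__):
            print(f"JAX {jax.__version__}: importing pymc3.sampling_jax will fail")
        else:
            print(f"JAX {jax.__version__}: disable_omnistaging() is still supported")
```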
Thanks in advance.