Omnistaging issue with JAX sampling

I tried to use the NumPyro/JAX sampler in a default Google Colab instance after updating PyMC3 to v3.11.2. Running import pymc3.sampling_jax raises the following error:

Traceback (most recent call last)
<ipython-input-51-04ca4ab85d88> in <module>()
----> 1 import pymc3.sampling_jax
      2 with model:
      3   trace = pm.sampling_jax.sample_numpyro_nuts()

/usr/local/lib/python3.7/dist-packages/pymc3/sampling_jax.py in <module>()
     16 import theano.graph.fg
     17 
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
     19 
     20 import pymc3 as pm

/usr/local/lib/python3.7/dist-packages/theano/link/jax/jax_dispatch.py in <module>()
     84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
     85 try:
---> 86     jax.config.disable_omnistaging()
     87 except AttributeError:
     88     pass

/usr/local/lib/python3.7/dist-packages/jax/config.py in disable_omnistaging(self)
    166   def disable_omnistaging(self):
    167     raise Exception(
--> 168       "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
    169       "see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.")
    170 

Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.
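The try/except in jax_dispatch.py does not help here because it only catches AttributeError. That is what JAX releases older than 0.2.0 raise, since they have no disable_omnistaging method at all. From 0.2.12 onward the method exists but raises a plain Exception, which passes straight through the AttributeError handler and aborts the import. The following minimal sketch uses hypothetical stand-in config classes (not the real jax.config) to reproduce the three cases:

```python
# Hypothetical stand-ins for jax.config in three eras of JAX (illustration only;
# these are not the real JAX classes).


class ConfigBefore020:
    """JAX < 0.2.0: there is no disable_omnistaging method at all."""


class Config020To0211:
    """JAX 0.2.0 to 0.2.11: omnistaging can still be switched off."""

    def __init__(self) -> None:
        self.omnistaging_enabled = True

    def disable_omnistaging(self) -> None:
        self.omnistaging_enabled = False


class Config0212AndLater:
    """JAX >= 0.2.12: the method still exists but always raises."""

    def disable_omnistaging(self) -> None:
        raise Exception(
            "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher"
        )


def theano_style_guard(config) -> None:
    """Same shape as the try/except in theano/link/jax/jax_dispatch.py."""
    try:
        config.disable_omnistaging()
    except AttributeError:
        # Only the "method is missing" case (old JAX) is tolerated.
        pass


def import_outcome(config) -> str:
    """Report whether the guard lets the import continue."""
    try:
        theano_style_guard(config)
    except Exception as exc:  # a plain Exception escapes the AttributeError handler
        return f"fails: {exc}"
    return "ok"


for cfg in (ConfigBefore020(), Config020To0211(), Config0212AndLater()):
    print(type(cfg).__name__, "->", import_outcome(cfg))
```

Only the last case fails. This is consistent with the later report in this thread that JAX 0.2.11 still works.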

It seems to be an issue with the JAX version (I am using JAX 0.2.12). How can I solve this problem and use the new JAX sampler?

Thanks in advance.

FWIW, I ran into this too. I opened an issue here: "sampling_jax import issues" (pymc-devs/pymc3 issue #4645 on GitHub).

I’m also curious whether a solution is in the works. For now, my workaround is to use JAX 0.2.11, i.e. one version below the current one, where disabling omnistaging is still possible.
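If you go the pinning route, you can check before importing pymc3.sampling_jax whether the installed JAX is older than the 0.2.12 cutoff named in the error message. This is a minimal sketch. The helper names (parse_version, passes_omnistaging_guard) are hypothetical. Passing the check only means the disable_omnistaging call should not raise. It does not guarantee that every other part of the sampler is compatible with that JAX release.

```python
import re


def parse_version(version: str) -> tuple:
    """Extract (major, minor, patch) from strings such as '0.2.11' or '0.2.12rc1'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


# Per the error message, JAX 0.2.12 and higher refuse to disable omnistaging.
OMNISTAGING_CUTOFF = (0, 2, 12)


def passes_omnistaging_guard(jax_version: str) -> bool:
    """True if theano's disable_omnistaging() call should not raise for this JAX version."""
    return parse_version(jax_version) < OMNISTAGING_CUTOFF


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not installed")
    else:
        if passes_omnistaging_guard(jax.__version__):
            print(f"JAX {jax.__version__}: should get past the omnistaging check")
        else:
            print(f"JAX {jax.__version__}: expect the omnistaging exception on import")
```

Tuple comparison keeps the check simple: (0, 2, 11) sorts before (0, 2, 12), while any later release compares as greater than or equal to the cutoff.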
