How to set up a PyMC environment on Google Cloud Compute Platform?

Well, I thought it was working great, but I'm getting the following error:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
/tmp/ipykernel_2309/275736203.py in <module>
      1 import pymc3 as pm
----> 2 import pymc3.sampling_jax
      3 
      4 import aesara
      5 from aesara import tensor as at

/opt/conda/lib/python3.7/site-packages/pymc3/sampling_jax.py in <module>
     16 import theano.graph.fg
     17 
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
     19 
     20 import pymc3 as pm

/opt/conda/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in <module>
     84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
     85 try:
---> 86     jax.config.disable_omnistaging()
     87 except AttributeError:
     88     pass

/opt/conda/lib/python3.7/site-packages/jax/_src/config.py in disable_omnistaging(self)
    183   def disable_omnistaging(self):
    184     raise Exception(
--> 185       "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
    186       "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.")
    187 

Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.
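Going by the traceback, `pymc3.sampling_jax` imports Theano's `jax_dispatch`, which calls `jax.config.disable_omnistaging()`. JAX 0.2.12 and later turn that call into an exception. Below is a minimal sketch for checking which JAX the notebook actually sees. It assumes it is run in the same conda environment as the notebook. The 0.2.12 cutoff is taken from the error message, and `parse_version` / `breaks_theano_jax_dispatch` are hypothetical helpers (a rough parser, not `packaging.version`).

```python
"""Sketch: does the installed JAX trip Theano's disable_omnistaging() call?"""
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 (as in the traceback) needs `pip install importlib-metadata`
    from importlib_metadata import PackageNotFoundError, version

# From the error text: "no longer supported in JAX version 0.2.12 and higher"
OMNISTAGING_CUTOFF = (0, 2, 12)


def parse_version(text):
    """Rough parse of 'X.Y.Z[suffix]' into a 3-tuple of ints, e.g. '0.2.25rc1' -> (0, 2, 25)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or '+cuda'
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad '0.3' -> (0, 3, 0)
    return tuple(parts)


def installed_version(dist_name):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def breaks_theano_jax_dispatch(jax_version):
    """True if jax.config.disable_omnistaging() raises in this JAX version."""
    return parse_version(jax_version) >= OMNISTAGING_CUTOFF


if __name__ == "__main__":
    jax_version = installed_version("jax")
    print("jax:", jax_version, "| jaxlib:", installed_version("jaxlib"))
    if jax_version is None:
        print("JAX is not installed in this environment.")
    elif breaks_theano_jax_dispatch(jax_version):
        print("JAX >= 0.2.12: the Theano jax_dispatch import will raise as in the traceback.")
    else:
        print("JAX predates 0.2.12: the omnistaging call should not raise.")
```

If this reports a JAX of 0.2.12 or newer, the exception above is expected for the Theano code path shown in the traceback, no matter how the VM itself was set up.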

That said, I'm getting confused about which installation instructions to use. The time series data I'll be working with is pretty big, which is why I was interested in trying the JAX sampler. Could you point me to the right installation instructions?
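One thing the traceback shows is that the same notebook imports both `pymc3` (whose JAX sampler pulls in `theano`) and `aesara`, so the environment may contain pieces of two different stacks. Here is a minimal sketch that flags this before choosing which instructions to follow. `classify_stack` and `installed_versions` are hypothetical helpers, not PyMC APIs. The grouping of PyPI names into a Theano stack (`pymc3`, `Theano-PyMC`) and an Aesara stack (`pymc`, `aesara`) is an assumption. An old PyMC 2 install, which is also published as `pymc`, would be misclassified.

```python
"""Sketch: flag an environment that mixes the PyMC3/Theano and PyMC/Aesara stacks."""
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 needs the importlib-metadata backport
    from importlib_metadata import PackageNotFoundError, version

# PyPI distribution names; grouping them into two stacks is an assumption.
THEANO_STACK = ("pymc3", "Theano-PyMC")
AESARA_STACK = ("pymc", "aesara")


def classify_stack(versions):
    """Given {dist_name: version or None}, return 'theano', 'aesara', 'mixed' or 'none'."""
    has_theano = any(versions.get(name) for name in THEANO_STACK)
    has_aesara = any(versions.get(name) for name in AESARA_STACK)
    if has_theano and has_aesara:
        return "mixed"
    if has_theano:
        return "theano"
    if has_aesara:
        return "aesara"
    return "none"


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    versions = installed_versions(THEANO_STACK + AESARA_STACK + ("jax", "jaxlib"))
    for name, ver in versions.items():
        print(f"{name:12s} {ver}")
    print("stack:", classify_stack(versions))
```

A `mixed` result would suggest starting from a fresh conda environment that follows one set of instructions, rather than layering one stack's packages on top of the other's.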