Well, I thought it was working great, but now I'm getting the following error:
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
/tmp/ipykernel_2309/275736203.py in <module>
1 import pymc3 as pm
----> 2 import pymc3.sampling_jax
3
4 import aesara
5 from aesara import tensor as at
/opt/conda/lib/python3.7/site-packages/pymc3/sampling_jax.py in <module>
16 import theano.graph.fg
17
---> 18 from theano.link.jax.jax_dispatch import jax_funcify
19
20 import pymc3 as pm
/opt/conda/lib/python3.7/site-packages/theano/link/jax/jax_dispatch.py in <module>
84 # Older versions < 0.2.0 do not have this flag so we don't need to set it.
85 try:
---> 86 jax.config.disable_omnistaging()
87 except AttributeError:
88 pass
/opt/conda/lib/python3.7/site-packages/jax/_src/config.py in disable_omnistaging(self)
183 def disable_omnistaging(self):
184 raise Exception(
--> 185 "Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: "
186 "see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.")
187
Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.
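As far as I can tell from the traceback, Theano's `jax_dispatch.py` still calls `jax.config.disable_omnistaging()`, and JAX 0.2.12 and later raise on that call. To make the version cutoff concrete, here is a minimal sketch that checks whether an installed JAX version is older than 0.2.12, the threshold quoted in the error message. The helper names `parse_version` and `omnistaging_can_be_disabled` are hypothetical (they are not part of JAX, Theano, or PyMC3), and the version parsing is deliberately simplified.

```python
import re

# Cutoff taken from the error message above: JAX 0.2.12 and higher
# no longer allow jax.config.disable_omnistaging().
OMNISTAGING_REMOVED_IN = (0, 2, 12)


def parse_version(v: str) -> tuple:
    """Turn a version string like '0.2.19' into (0, 2, 19).

    Hypothetical helper: only the leading numeric part of each component
    is used, and parsing stops at the first non-numeric component
    (e.g. 'dev0'). This is a simplification, not full PEP 440 parsing.
    """
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    # Pad to three components so "0.3" compares as (0, 3, 0).
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def omnistaging_can_be_disabled(jax_version: str) -> bool:
    """Return True if this JAX version is older than the 0.2.12 cutoff."""
    return parse_version(jax_version) < OMNISTAGING_REMOVED_IN


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("jax is not installed")
    else:
        ok = omnistaging_can_be_disabled(jax.__version__)
        print(f"jax {jax.__version__}: disable_omnistaging() allowed -> {ok}")
```

If this reports `False`, the installed JAX is at or past the version where the call in `jax_dispatch.py` fails, which matches the exception above.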
That said, I'm getting confused about which installation instructions to use. The time series data I'll be working with is pretty big, which is why I was interested in trying the JAX sampler. Could you point me to the right installation instructions?