JAX sampling in PyMC 4: no pm.sampling.jax module

I'm having exactly the same problem. I'm using v4.3.0, installed by following the instructions on the Installation page of the PyMC 4.4.0 documentation.

I also noticed that on Google Colab there doesn't appear to be any JAX functionality set up.
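One possible explanation is a mismatch between the installed version (4.3.0) and the docs being read (4.4.0). If memory serves, the JAX samplers lived in `pymc.sampling_jax` in earlier 4.x releases and later moved to `pymc.sampling.jax`, but that is an assumption worth checking against the API reference for your exact version. The sketch below checks which of the two module paths the current environment can locate. It uses only the standard library (`importlib.util.find_spec`). It does not import the target module, so JAX itself does not need to be installed for the check to run. The helper name `find_jax_sampling_module` is hypothetical.

```python
import importlib.util


def find_jax_sampling_module(candidates=("pymc.sampling.jax", "pymc.sampling_jax")):
    """Return the first candidate module path that can be located, else None.

    The default candidates are assumptions about where different PyMC
    releases keep their JAX-based samplers. Only existence is checked;
    the target module is not imported.
    """
    for name in candidates:
        try:
            # find_spec imports the parent package (e.g. ``pymc``) but not
            # the target module itself.
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Parent package is missing, or the parent is a plain module
            # with no __path__ (as ``pymc.sampling`` would be if it is a
            # single file rather than a subpackage).
            continue
        if spec is not None:
            return name
    return None


if __name__ == "__main__":
    found = find_jax_sampling_module()
    if found is None:
        print("No JAX sampling module found; check the installed PyMC version.")
    else:
        print(f"JAX samplers appear to live in: {found}")
```

If it reports `pymc.sampling_jax`, that version most likely expects `import pymc.sampling_jax` followed by something like `pymc.sampling_jax.sample_numpyro_nuts()` inside a model context. Confirm the exact function names in that release's docs before relying on them.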
