JAX sampling in PyMC v4: no pm.sampling.jax module

So, I was really excited to be able to use JAX in PyMC v4, but I am having a hard time setting it up.

I followed the instructions in the thread "Set up environment for JAX sampling with GPU support in PyMC v4" (post #4 by payamphysics).

Everything seems to be set up: JAX imports, jax.default_backend() returns 'gpu', and the device list shows [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)].
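The check described above can be scripted so it behaves the same on a local GPU box and on a hosted notebook. This is a minimal sketch that only assumes the jax package may or may not be installed; describe_jax_backend is a hypothetical helper name, while jax.default_backend() and jax.devices() are JAX's own functions.

```python
import importlib.util


def describe_jax_backend():
    """Return (backend, devices) as reported by JAX, or None if JAX is not installed.

    describe_jax_backend is a hypothetical helper for this post, not part of PyMC or JAX.
    """
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not importable in this environment
    import jax

    # default_backend() gives the platform name (e.g. "cpu" or "gpu");
    # devices() lists the devices JAX will place computations on.
    return jax.default_backend(), jax.devices()


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not installed")
    else:
        backend, devices = info
        print("backend:", backend)
        print("devices:", devices)
```

A working GPU setup should report a GPU backend and at least one GPU device. Note that this only confirms JAX itself is working; it says nothing yet about whether PyMC's JAX sampling module is available.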

But my PyMC v4 install does not seem to expose any of the JAX sampling functions:

  • there is no pm.sampling_jax module
  • pm.sampling has no module called jax
  • I cannot find a sample_numpyro_nuts() function anywhere
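To see which of these module paths an install actually provides, you can ask Python's import system without importing the experimental module itself. This is a minimal sketch: the two candidate names come from this thread (pymc.sampling_jax from the question, pymc.sampling.jax from the answer below), and find_jax_sampling_module and pymc_version are hypothetical helpers.

```python
import importlib.metadata
import importlib.util

# Candidate locations of PyMC's JAX sampling module, taken from this thread.
CANDIDATES = ("pymc.sampling.jax", "pymc.sampling_jax")


def pymc_version():
    """Installed PyMC version string, or None if PyMC is not installed."""
    try:
        return importlib.metadata.version("pymc")
    except importlib.metadata.PackageNotFoundError:
        return None


def find_jax_sampling_module():
    """Return the first candidate module name that exists on disk, or None.

    find_spec imports the parent packages (pymc, pymc.sampling) but not the
    target module itself, so the module's own import-time code does not run.
    """
    for name in CANDIDATES:
        try:
            if importlib.util.find_spec(name) is not None:
                return name
        except ImportError:
            # Raised when a parent is missing or is a plain module rather than a package
            continue
    return None


if __name__ == "__main__":
    print("pymc version:", pymc_version())
    print("JAX sampling module:", find_jax_sampling_module())
```

If this reports a module name, it exists but is simply not attribute-accessible from pm until it has been imported explicitly.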

I have PyMC v4.4 installed. Can anyone help?


I'm having exactly the same problem. I'm using v4.3.0, installed following the instructions in the PyMC installation guide (PyMC 4.4.0 documentation).

I also noticed that on Google Colab there doesn't appear to be any JAX functionality set up.


Got it!

You have to import the module explicitly, because a plain import of pymc does not load it:

import pymc.sampling.jax as pmjax

You get a warning that the module is experimental, but after that you can do:

pmjax.sample_numpyro_nuts(1000)
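Note that sample_numpyro_nuts samples the model in the current with pm.Model(): context (or one passed via its model argument), so the one-liner above has to run inside a model block. Below is a minimal end-to-end sketch, assuming a PyMC release that ships pymc.sampling.jax (as the 4.3 and 4.4 installs in this thread appear to) with jax and numpyro installed. The toy Normal model, the simulated data, and the helper name fit_normal_with_numpyro are hypothetical illustrations, not from the thread. The imports sit inside the function so that merely defining it does not require the optional JAX stack.

```python
import numpy as np


def fit_normal_with_numpyro(data, draws=1000, tune=1000, chains=2, seed=0):
    """Fit a toy Normal model to `data` with PyMC's NumPyro NUTS sampler.

    Hypothetical helper for illustration. Imports are deferred because
    pymc.sampling.jax depends on optional packages (jax, numpyro) and
    warns that it is experimental when first imported.
    """
    import pymc as pm
    import pymc.sampling.jax as pmjax

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        # Must be called inside the model context (or with model=...)
        idata = pmjax.sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, random_seed=seed
        )
    return idata


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    data = rng.normal(loc=1.0, scale=2.0, size=200)
    try:
        idata = fit_normal_with_numpyro(data)
    except ImportError as err:
        # pymc, jax or numpyro is missing in this environment
        print("JAX sampling unavailable:", err)
    else:
        # Posterior draws are indexed by (chain, draw)
        print(idata.posterior["mu"].shape)
```

The result is an ArviZ InferenceData object, so the usual ArviZ summaries and plots work on it as they do with pm.sample().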

Great work, thanks!