So, I was really excited to be able to use JAX in PyMC v4, but I am having a hard time setting it up.
I followed the instructions in the thread "Set up environment for JAX sampling with GPU support in PyMC v4" (post #4 by payamphysics).
Everything seems to be set up: JAX imports, jax.default_backend() returns 'gpu', and the GPU shows up as a device: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)].
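For reference, the environment check above amounts to roughly the following. This is a minimal sketch, and the helper name jax_backend_report is hypothetical. On the setup described here it should return 'gpu' together with a GPU device list.

```python
import importlib.util


def jax_backend_report():
    """Return (default backend, device list) if JAX is installed, else None."""
    if importlib.util.find_spec("jax") is None:
        # JAX is not installed in this environment.
        return None
    import jax

    # default_backend() returns a platform string such as "cpu" or "gpu";
    # devices() lists the devices JAX can place computations on.
    return jax.default_backend(), jax.devices()


if __name__ == "__main__":
    print(jax_backend_report())
```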
But my PyMC v4 installation does not seem to expose any of the JAX sampling functions:
- there is no pm.sampling_jax module
- pm.sampling has no submodule called jax
- I cannot find a sample_numpyro_nuts() function anywhere
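The lookups listed above can be reproduced with a small standard-library script. This is a sketch, not an official PyMC tool. The helper names module_exists and locate_numpyro_sampler are hypothetical, and the two candidate module paths are assumptions to probe. The script imports each submodule explicitly instead of relying on it already being an attribute of pm after a plain import pymc.

```python
import importlib
import importlib.util

# Candidate locations for the JAX samplers. Both names are assumptions to probe,
# not a statement of where any particular PyMC release keeps them.
CANDIDATES = ("pymc.sampling_jax", "pymc.sampling.jax")


def module_exists(name):
    """Return True if `name` can be located, False otherwise."""
    try:
        # find_spec on a dotted name imports the parent package first; a missing
        # parent (or a parent that is a plain module, not a package) raises.
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def locate_numpyro_sampler(candidates=CANDIDATES):
    """Return the first candidate module exposing sample_numpyro_nuts, else None."""
    for name in candidates:
        if not module_exists(name):
            continue
        try:
            # Import the submodule explicitly rather than relying on it already
            # being an attribute of the parent package.
            module = importlib.import_module(name)
        except ImportError:
            # e.g. the module exists but an optional dependency is missing.
            continue
        if hasattr(module, "sample_numpyro_nuts"):
            return name
    return None


if __name__ == "__main__":
    print("pymc installed:", module_exists("pymc"))
    for name in CANDIDATES:
        print(name, "found:", module_exists(name))
    print("sample_numpyro_nuts found in:", locate_numpyro_sampler())
```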
I have PyMC v4.4 installed. Can anyone help?