Everything seems to be set up. JAX imports, jax.default_backend() reports 'gpu', and the device list shows [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)].
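For reference, a check like the one described above can be reproduced with a short script. This is a minimal sketch that assumes nothing beyond JAX's public jax.default_backend() and jax.devices(); it returns None instead of failing when JAX is not installed.

```python
import importlib.util


def jax_backend_report():
    """Return (backend name, device list) if JAX is importable, else None."""
    # find_spec on a top-level name returns None when the package is missing
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # default_backend() names the platform JAX will use ('gpu' on a working CUDA setup)
    return jax.default_backend(), jax.devices()


report = jax_backend_report()
print(report)
```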
However, my PyMC 4 installation does not seem to expose any of the JAX sampling functions:
- There is no pm.sampling_jax module.
- pm.sampling has no submodule called jax.
- I cannot find a sample_numpyro_nuts() function anywhere.
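The three checks in the list can be scripted as a sketch. It only tries the two module paths named above (pymc.sampling_jax and pymc.sampling.jax) and looks for sample_numpyro_nuts in each; which path, if any, exists depends on the installed PyMC version, so treat this as a diagnostic rather than a statement of the library's layout. It imports each module by its full dotted path instead of relying on attribute access on pm, since Python submodules are not always available as attributes until they have been imported explicitly.

```python
import importlib

# Module paths taken from the checks listed above; which one exists
# (if any) depends on the installed PyMC version.
CANDIDATE_MODULES = ("pymc.sampling_jax", "pymc.sampling.jax")


def find_jax_sampler():
    """Return the first candidate module that defines sample_numpyro_nuts, or None."""
    for name in CANDIDATE_MODULES:
        try:
            # Import by full dotted path so an unimported submodule is still found
            module = importlib.import_module(name)
        except ImportError:
            # Covers a missing pymc, a missing submodule, or a missing optional dependency
            continue
        if hasattr(module, "sample_numpyro_nuts"):
            return name
    return None


print(find_jax_sampler())
```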