Batch process capability for pymc.sampling_jax.sample_numpyro_nuts() with GPU?

That’s more of a JAX/NumPyro question, so you may have more luck asking in their forums.