Batch process capability for pymc.sampling_jax.sample_numpyro_nuts() with GPU?

You can try setting chain_method to "sequential" or "vectorized" and see whether it does any better.
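Something like the following is a rough sketch of how you could compare the options. The helper name `try_chain_methods` and the draw/tune/chain counts are made up for illustration, and the import path is the one from the question title (newer PyMC releases may expose the function elsewhere). It assumes the chain_method string is handed on to NumPyro's MCMC, which accepts "parallel", "sequential" and "vectorized".

```python
import time

# Chain methods accepted by NumPyro's MCMC (assumed to be passed through by PyMC).
CHAIN_METHODS = ("parallel", "sequential", "vectorized")


def try_chain_methods(model, methods=CHAIN_METHODS, draws=500, tune=500, chains=4):
    """Hypothetical helper: run the sampler once per chain_method and time it."""
    unknown = set(methods) - set(CHAIN_METHODS)
    if unknown:
        raise ValueError(f"unknown chain_method(s): {sorted(unknown)}")

    # Imported lazily so the argument check above works without JAX installed.
    from pymc.sampling_jax import sample_numpyro_nuts

    timings = {}
    for method in methods:
        start = time.perf_counter()
        with model:  # model is an existing pm.Model
            sample_numpyro_nuts(
                draws=draws, tune=tune, chains=chains, chain_method=method
            )
        # Wall-clock seconds, including JIT compilation for this run.
        timings[method] = time.perf_counter() - start
    return timings
```

Call it as `try_chain_methods(my_model)` and compare the timings while watching GPU memory. Each timing includes compilation, so treat the numbers as a rough guide only.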

It is normal for tuning to need less RAM than drawing, because tuning samples can be discarded as sampling goes on, whereas every posterior draw has to be kept.
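As a back-of-the-envelope illustration of why the drawing phase is the expensive one, here is a small sketch that estimates how much memory the retained draws take. The function name and the example numbers are hypothetical, and it assumes dense storage at 4 bytes per value (float32, JAX's default precision); real usage will be higher because of the sampler's own working memory.

```python
def draw_storage_bytes(chains, draws, values_per_draw, bytes_per_value=4):
    """Bytes needed to hold every retained draw (assumes dense float storage)."""
    return chains * draws * values_per_draw * bytes_per_value


# Example: 4 chains, 1000 draws each, 100,000 scalar values per draw.
kept = draw_storage_bytes(chains=4, draws=1000, values_per_draw=100_000)
print(f"{kept / 1e9:.1f} GB")  # 4 * 1000 * 100_000 * 4 bytes = 1.6e9 bytes
```

The estimate grows linearly with the number of draws, so halving the draws or the number of stored variables roughly halves this part of the footprint.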