Batch process capability for pymc.sampling_jax.sample_numpyro_nuts() with GPU?

Thanks a lot for your reply!

I’ve tried that, but it didn’t seem to improve much. I’m now running more tuning steps so that the sampler reaches a stable posterior before drawing, and this seems to improve the accuracy when modelling my synthetic data.

Are there any other ways to deal with the out-of-memory (OOM) issue on the GPU?