Update: from my testing, the high memory usage is happening during the vmap step:
result = jax.vmap(jax.vmap(jax_fn))(
    *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)
where postprocessing_backend = "cpu". Is there some way to make this process require less RAM?
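
One thing that might help is running the double vmap over slices of the draws axis rather than over everything at once, so only one slice's worth of intermediates is alive at a time. Below is a rough sketch of that idea, not a tested fix for this issue. It assumes each array in raw_mcmc_samples is shaped (chains, draws, ...) and that jax_fn returns a tuple of arrays; the jax_fn and sample data here are made-up stand-ins, and postprocess_in_chunks and chunk_size are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_fn(x, y):
    # Hypothetical stand-in for the real per-draw transform.
    return x * 2.0, jnp.exp(y)


# Hypothetical stand-in for raw_mcmc_samples: arrays shaped (chains, draws, ...).
raw_mcmc_samples = (np.ones((4, 1000, 3)), np.zeros((4, 1000)))

cpu = jax.devices("cpu")[0]


def postprocess_in_chunks(fn, samples, chunk_size=250):
    """Apply vmap(vmap(fn)) to slices of the draws axis and stitch the results."""
    n_draws = samples[0].shape[1]
    batched = jax.jit(jax.vmap(jax.vmap(fn)))
    pieces = []
    for start in range(0, n_draws, chunk_size):
        # Only this slice is placed on the device and transformed.
        chunk = [
            jax.device_put(s[:, start:start + chunk_size], cpu) for s in samples
        ]
        out = batched(*chunk)
        # Copy the result to host NumPy so the JAX buffers can be freed.
        pieces.append(jax.tree_util.tree_map(np.asarray, out))
    # Concatenate the per-chunk outputs back along the draws axis.
    return jax.tree_util.tree_map(
        lambda *xs: np.concatenate(xs, axis=1), *pieces
    )


result = postprocess_in_chunks(jax_fn, raw_mcmc_samples, chunk_size=250)
```

Whether this lowers peak RAM depends on where the memory actually goes. It only helps if the intermediates inside jax_fn dominate; the full concatenated output still has to fit in memory at the end. A chunk_size that does not divide the number of draws evenly will trigger one extra compilation for the shorter final slice.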