Out of memory when "transforming variables" in NumPyro & JAX

Update: from my testing, the out-of-memory error occurs during the vmap step:

result = jax.vmap(jax.vmap(jax_fn))(
    *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)

where postprocessing_backend = "cpu". Is there a way to make this step use less RAM?
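A minimal sketch of why this step is memory-hungry and one way to bound it. It assumes raw_mcmc_samples is a list of arrays shaped (chains, draws, ...), and it uses a hypothetical toy_jax_fn in place of the real jax_fn. The nested jax.vmap in the snippet batches the transform over every chain and draw in a single computation, so the intermediate arrays for the whole posterior are live at the same time. The alternative shown here is a swapped-in technique, not what the original code does: loop over chains sequentially with jax.lax.map and vmap only over the draws within one chain.

```python
import jax
import jax.numpy as jnp


def toy_jax_fn(x, y):
    # Hypothetical stand-in for the real jax_fn: maps one draw's raw
    # (unconstrained) values to transformed values.
    return jnp.exp(x), jnp.tanh(y)


def transform_all_at_once(jax_fn, samples):
    # Mirrors the snippet above: vmap over chains, then over draws, so every
    # chain and draw is processed in one batched computation.
    return jax.vmap(jax.vmap(jax_fn))(*samples)


def transform_chain_by_chain(jax_fn, samples):
    # jax.lax.map iterates over the leading (chain) axis sequentially; the inner
    # vmap batches only over one chain's draws, so only that chain's
    # intermediates are live at once. Outputs are stacked back along axis 0.
    per_chain = jax.vmap(jax_fn)
    return jax.lax.map(lambda chain: per_chain(*chain), tuple(samples))


# Toy data shaped (chains, draws, ...): 4 chains x 100 draws, placed on the CPU.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
samples = [jax.random.normal(kx, (4, 100, 3)), jax.random.normal(ky, (4, 100))]
samples = jax.device_put(samples, jax.devices("cpu")[0])

batched = transform_all_at_once(toy_jax_fn, samples)
sequential = transform_chain_by_chain(toy_jax_fn, samples)
```

Both functions return the same values with shapes (chains, draws, ...). Note that this only bounds the intermediates: the transformed output for all chains and draws still has to fit in RAM. If a single chain is still too large, the same idea can be applied one level down by replacing the inner vmap with another jax.lax.map over draws, trading speed for lower peak memory.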