Out of memory when "transforming variables" in Numpyro & JAX

In fact, this seems to work:

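import jax
from jax.experimental.maps import SerialLoop, xmap  # experimental; deprecated in later JAX releases

# jax_fn, raw_mcmc_samples and postprocessing_backend come from the surrounding sampling code.
num_chunks = 10  # the number of samples must be divisible by num_chunks
# Map over both named axes, but evaluate the "samples" axis as a serial loop of num_chunks chunks.
mapper = xmap(jax_fn,
              in_axes=["chain", "samples", ...],
              out_axes=["chain", "samples", ...],
              axis_resources={"samples": SerialLoop(num_chunks)})
loop = xmap(mapper,
            in_axes=[...],
            out_axes=[...])
result = loop(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))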

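This massively decreases memory usage, because the "samples" axis is evaluated in num_chunks sequential chunks instead of all at once, and it returns the same result as vmap.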
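Since xmap is experimental and has been deprecated in later JAX releases, here is a minimal sketch of the same chunking idea using only jax.vmap and jax.lax.map. This is a swapped-in technique, not the xmap code above. It assumes every argument has the samples axis first and that the number of samples is divisible by num_chunks; chunked_vmap and toy_transform are hypothetical names, with toy_transform standing in for jax_fn.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, num_chunks):
    """Hypothetical helper: map `fn` over the leading axis of every argument,
    vectorizing within each chunk and looping serially over `num_chunks` chunks."""

    def wrapped(*args):
        n = args[0].shape[0]
        if n % num_chunks != 0:
            raise ValueError("leading axis size must be divisible by num_chunks")
        size = n // num_chunks
        # (n, ...) -> (num_chunks, size, ...) for every argument
        split = [a.reshape((num_chunks, size) + a.shape[1:]) for a in args]
        # lax.map runs the chunks one after another; vmap vectorizes inside a chunk,
        # so only one chunk's intermediates need to be live at a time.
        out = jax.lax.map(lambda chunk: jax.vmap(fn)(*chunk), split)
        # (num_chunks, size, ...) -> (n, ...) for every output leaf
        return jax.tree_util.tree_map(lambda o: o.reshape((n,) + o.shape[2:]), out)

    return wrapped


def toy_transform(x, y):
    # Hypothetical stand-in for jax_fn: maps one draw of each variable to outputs.
    return jnp.exp(x), x + jnp.tanh(y)


num_chains, num_samples = 4, 1000
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (num_chains, num_samples))
y = jax.random.normal(key_y, (num_chains, num_samples))

# Reference: the fully vectorized nested vmap over (chain, sample).
reference = jax.vmap(jax.vmap(toy_transform))(x, y)
# Chunked: vmap over chains, 10 serial chunks of 100 samples each.
chunked = jax.vmap(chunked_vmap(toy_transform, num_chunks=10))(x, y)
print(all(bool(jnp.allclose(r, c)) for r, c in zip(reference, chunked)))
```

Here jax.vmap handles the "chain" axis and chunked_vmap plays the role of SerialLoop(num_chunks) on the "samples" axis. Recent JAX releases also accept a batch_size argument to jax.lax.map that does this chunking internally, so it may be worth checking the documentation for your installed version before hand-rolling it.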
