In fact, this seems to work:
import jax
from jax.experimental.maps import SerialLoop, xmap

num_chunks = 10  # must evenly divide the number of samples

# Inner xmap: run the 'samples' axis as a serial loop of num_chunks steps.
mapper = xmap(jax_fn,
              in_axes=['chain', 'samples', ...],
              out_axes=['chain', 'samples', ...],
              axis_resources={'samples': SerialLoop(num_chunks)})
loop = xmap(mapper,
            in_axes=[...],
            out_axes=[...])
# Move the raw samples to the first device of the postprocessing backend before mapping.
result = loop(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
This massively decreases memory usage and returns the same result as vmap.
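For anyone who wants to sanity-check the "same result as vmap" part on a toy input, here is a minimal sketch of the same chunking idea. Note that it does not use xmap or SerialLoop: it swaps in `jax.lax.map` over reshaped chunks as a stand-in. `chunked_vmap` and `toy_fn` are hypothetical names made up for this illustration (`toy_fn` stands in for `jax_fn`), and the sketch assumes a single array input whose leading axis is the sample axis.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, xs, num_chunks):
    """Apply fn over axis 0 of xs in num_chunks sequential chunks (hypothetical helper)."""
    n = xs.shape[0]
    if n % num_chunks != 0:
        raise ValueError("num_chunks must evenly divide the number of samples")
    chunk_size = n // num_chunks
    # Split the leading axis into (num_chunks, chunk_size, ...).
    chunks = xs.reshape((num_chunks, chunk_size) + xs.shape[1:])
    # lax.map visits the chunks one after another; vmap vectorizes within each chunk.
    out = jax.lax.map(jax.vmap(fn), chunks)
    # Merge the two leading axes back into a single sample axis.
    return out.reshape((n,) + out.shape[2:])


def toy_fn(x):
    # Hypothetical per-sample function standing in for jax_fn.
    return jnp.sum(x ** 2)


samples = jnp.arange(40.0).reshape(20, 2)  # 20 samples, 2 values each
chunked = chunked_vmap(toy_fn, samples, num_chunks=10)
full = jax.vmap(toy_fn)(samples)
```

Comparing `chunked` against `full` with `jnp.allclose` is a quick way to confirm the two paths agree. Only one chunk's intermediates are live at a time, which is presumably where the memory saving comes from, though I have not measured it for this variant.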