BlackJAX/NumPyro (at least the way PyMC calls them via `nuts_sampler="blackjax"` / `"numpyro"`) only return samples after all sampling is done, so you won't avoid memory problems by trying to save them to mcbackend.
Does running the PyMC sampler with `compile_kwargs=dict(mode="JAX")` give you any speedup?
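You may need to change `mp_ctx` to something compatible with JAX, such as `"spawn"` or `"forkserver"`. JAX is multithreaded and does not cope well with forked processes.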