Out of memory when "transforming variables" in Numpyro & JAX

@ferrine That made very little difference for me.

I'm still limited to running very small chains.

An option to use jax.lax.map (see the JAX documentation) in the variable-transformation step should reduce memory usage and avoid the memory spike.
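Here is a minimal sketch of why jax.lax.map can help, assuming the post-sampling transformation is applied independently to each draw. The function transform_draw and the array sizes are hypothetical stand-ins, not PyMC's or NumPyro's actual code. jax.vmap evaluates the transform for all draws at once, so its intermediates can be allocated for every draw simultaneously. jax.lax.map loops over the leading axis, so only one draw's intermediates need to be live at a time.

```python
import jax
import jax.numpy as jnp


# Hypothetical per-draw transform (not the library's real code): maps one
# unconstrained draw back to constrained space and computes a deterministic.
def transform_draw(z):
    sigma = jnp.exp(z)  # inverse of a log transform -> positive values
    # Deliberately large (dim x dim) intermediate, standing in for a
    # deterministic that is much bigger than the draw itself.
    outer = jnp.outer(sigma, sigma)
    return sigma, outer.sum(axis=0)


def transform_vmap(draws):
    # Vectorized over all draws: intermediates may be allocated for every
    # draw at the same time, so peak memory grows with the number of draws.
    return jax.vmap(transform_draw)(draws)


def transform_lax_map(draws):
    # Sequential loop over the leading axis: one draw's intermediates at a
    # time, trading speed for a lower peak memory footprint.
    return jax.lax.map(transform_draw, draws)


key = jax.random.PRNGKey(0)
n_chains, n_draws, dim = 4, 250, 50
samples = jax.random.normal(key, (n_chains, n_draws, dim))
flat = samples.reshape(-1, dim)  # (chains * draws, dim)

sig_v, col_v = transform_vmap(flat)
sig_m, col_m = transform_lax_map(flat)

# Restore the (chains, draws, dim) layout for the sequential result.
sig_m = sig_m.reshape(n_chains, n_draws, dim)
```

Both versions return the same values; the difference is only in how much is held in memory at once. Because jax.lax.map runs the draws one after another, it will usually be slower than jax.vmap, so it is mainly worth it when the vectorized version does not fit in memory.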