Out of memory when "transforming variables" in NumPyro & JAX

CC @ferrine