Out of memory when "transforming variables" in NumPyro & JAX

I’ll definitely try that! Thank you!