Out of memory when "transforming variables" in NumPyro & JAX

As far as I could tell, the speed was the same.

Sure 👍