Out of memory when "transforming variables" in NumPyro & JAX

The only way I know so far to use less RAM is to reduce the number of deterministic variables.
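A rough back-of-the-envelope estimate shows why dropping deterministic variables can make such a difference. This is a minimal sketch. It assumes that every posterior draw of a deterministic site is kept in memory as a dense array of shape `(num_chains, num_samples, *site_shape)`, with 4 bytes per element (float32, JAX's default when 64-bit mode is off). The helper `deterministic_site_bytes` and the example sizes are hypothetical and not part of NumPyro.

```python
import math


def deterministic_site_bytes(site_shape, num_samples, num_chains=1, itemsize=4):
    """Rough memory footprint of one stored deterministic site.

    Assumes all draws are kept as a dense array of shape
    (num_chains, num_samples, *site_shape) with `itemsize` bytes per element
    (4 bytes for float32).
    """
    elements_per_draw = math.prod(site_shape)  # math.prod(()) == 1 for scalar sites
    return num_chains * num_samples * elements_per_draw * itemsize


# Hypothetical sizes: 4 chains x 2,000 draws each.
per_observation = deterministic_site_bytes((100_000,), num_samples=2_000, num_chains=4)
scalar_site = deterministic_site_bytes((), num_samples=2_000, num_chains=4)

print(f"per-observation site: {per_observation / 1e9:.1f} GB")
print(f"scalar site: {scalar_site / 1e3:.0f} kB")
```

Under these assumptions, a deterministic site that stores one value per observation costs about 3.2 GB, while a scalar site costs only about 32 kB. This suggests that the deterministic variables most worth removing, or recomputing afterwards from the latent samples, are the ones whose shape grows with the size of the data.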