Maybe this is relevant? Out of memory when "transforming variables" in Numpyro & JAX - #11 by wnorcbrown