PyMC, NumPyro GPs and transformed RVs: memory behaviour

Maybe this is relevant? Out of memory when "transforming variables" in Numpyro & JAX - #11 by wnorcbrown