Out of memory when "transforming variables" in Numpyro & JAX

Thank you for the recommendation. I don't think it would fix this problem, though, because those parameters aren't used at the step that is failing. (Of course, you know the code base far better than I do, so please correct me if I'm wrong.) I have seen this suggested for other RAM problems, though, so I'll definitely keep it in mind. So far, reducing the number of Deterministic variables has helped a lot.
