Out of memory when "transforming variables" in NumPyro & JAX

This is really nice! How does the speed compare to vmap?

Would you like to send a PR to PyMC?
