Out of Memory when using pm.sampling.jax.sample_blackjax_nuts

I ended up switching to pm.sampling.jax.sample_numpyro_nuts and setting %env XLA_PYTHON_CLIENT_PREALLOCATE=false (check here), which stops JAX from grabbing a large chunk of GPU memory up front. That solved the issue. I still don't know why I wasn't able to use the Blackjax implementation, though.
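For anyone running this as a script rather than in a notebook, here is a minimal sketch of the same workaround. It assumes PyMC v5 with NumPyro installed. The toy Normal model, the synthetic data and the helper name sample_with_numpyro are hypothetical and only for illustration. The one important detail is ordering: the environment variable has to be set before JAX initializes, which is why the PyMC/JAX imports sit inside the function.

```python
import os

# JAX preallocates a large fraction of GPU memory by default. Turning that off
# only works if the variable is set before JAX initializes its backend, so it
# goes at the very top, before jax or pymc.sampling.jax is imported anywhere.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def sample_with_numpyro(draws=1000, tune=1000, chains=2, seed=0):
    """Hypothetical helper: fit a toy Normal model with NumPyro-backed NUTS."""
    # Imported here (not at module top) so the environment variable above is
    # guaranteed to be in place before JAX is loaded.
    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_numpyro_nuts

    rng = np.random.default_rng(seed)
    # Synthetic data, for illustration only.
    y = rng.normal(loc=1.0, scale=2.0, size=200)

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)
        # Used in place of sample_blackjax_nuts from the original setup.
        idata = sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, random_seed=seed
        )
    return idata
```

Usage would be something like idata = sample_with_numpyro() followed by idata.posterior["mu"].mean(). In Jupyter, putting %env XLA_PYTHON_CLIENT_PREALLOCATE=false in the first cell does the same thing, but if JAX has already been imported in that kernel you need to restart it for the setting to take effect.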
