I ended up switching to pm.sampling.jax.sample_numpyro_nuts and setting %env XLA_PYTHON_CLIENT_PREALLOCATE=false (check here). That solved the issue. I still don't know why I wasn't able to use the BlackJAX implementation, though.
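For anyone hitting the same problem outside a notebook, here is a minimal sketch of the same workaround as a plain Python script. It assumes PyMC v5 with JAX and NumPyro installed. The toy model, the synthetic data and the sample_toy_model helper are hypothetical and only there for illustration. The important detail is that the environment variable is set before JAX starts up.

```python
import os

# Must be set before JAX initializes its GPU backend. "false" tells XLA not to
# grab most of the GPU memory up front. This is the plain-Python equivalent of
# the IPython magic `%env XLA_PYTHON_CLIENT_PREALLOCATE=false`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def sample_toy_model(draws=500, tune=500, chains=2, seed=0):
    """Fit a hypothetical toy model with PyMC's NumPyro-backed NUTS sampler."""
    # Imported here, after the environment variable is set, so JAX
    # (pulled in by pymc.sampling.jax) sees the setting when it starts up.
    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_numpyro_nuts

    rng = np.random.default_rng(seed)
    data = rng.normal(loc=1.0, scale=2.0, size=100)  # synthetic data, illustration only

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        # Uses the model from the enclosing context; returns arviz.InferenceData.
        idata = sample_numpyro_nuts(
            draws=draws, tune=tune, chains=chains, random_seed=seed
        )
    return idata
```

Calling idata = sample_toy_model() should then give you an InferenceData with mu and sigma in idata.posterior. In more recent PyMC releases the same backend can also be selected with pm.sample(nuts_sampler="numpyro"), and nuts_sampler="blackjax" is a quick way to retry the BlackJAX path with the same memory setting.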