Using `float32` for `pymc.sampling.jax.sample_blackjax_nuts`

I found the solution: any changes to `pytensor.config` (such as setting `floatX` to `"float32"`) need to happen before `pymc` is imported.
