Error when using `sample_blackjax_nuts` with `pytensor.scan`

I see - so I assume it is not straightforward to implement. Are there alternatives to scan? E.g., would rewriting the model as an explicit Python loop and letting JAX optimize the loop work?
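
To make the question concrete, here is a minimal sketch in plain JAX (not PyTensor or PyMC) of what I mean by an explicit loop. The recursion and the names `step`, `unrolled`, and `scanned` are made up for illustration and are not from my actual model. A Python `for` loop inside a jitted function is unrolled at trace time, so it should compute the same values as `jax.lax.scan` as long as the number of steps is static. Whether the same trick carries over to a PyTensor graph sampled with `sample_blackjax_nuts` is exactly what I am unsure about.

```python
import jax
import jax.numpy as jnp


def step(x, eps, rho=0.9):
    # Hypothetical one-step recursion: x_t = rho * x_{t-1} + eps_t
    return rho * x + eps


@jax.jit
def unrolled(x0, eps):
    # Explicit Python loop: unrolled when JAX traces the function,
    # so eps.shape[0] must be known at trace time.
    xs = []
    x = x0
    for t in range(eps.shape[0]):
        x = step(x, eps[t])
        xs.append(x)
    return jnp.stack(xs)


@jax.jit
def scanned(x0, eps):
    # Same recursion expressed with jax.lax.scan for comparison.
    def body(x, e):
        x_new = step(x, e)
        return x_new, x_new

    _, xs = jax.lax.scan(body, x0, eps)
    return xs


eps = jnp.linspace(-1.0, 1.0, 8)
x0 = jnp.asarray(0.0)
xs_loop = unrolled(x0, eps)
xs_scan = scanned(x0, eps)
```

One caveat I am aware of: because the loop is unrolled, the traced program grows with the number of steps, so compile time may become a problem for long sequences.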