I find it in PyMC 5.10.3 but not in PyMC 5.15.0. Also, in PyMC 5.10.3 it does not automatically give a warning when r-hat is larger than 1.
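If the sampler does not warn on its own, you can check r-hat yourself after sampling. Below is a minimal NumPy-only sketch of the classic split R-hat for one scalar parameter, assuming you have already pulled its draws out as a `(chains, draws)` array. The names `split_rhat` and `warn_if_unconverged` are hypothetical helpers, not PyMC API. The 1.01 threshold is a common rule of thumb and is also an assumption. ArviZ's `az.rhat` defaults to the rank-normalized variant, so its numbers can differ slightly from this one.

```python
import warnings

import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split R-hat for a (chains, draws) array of one scalar parameter."""
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split each chain into two halves (drops the middle draw if n_draws is odd)
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()   # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)       # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


def warn_if_unconverged(draws: np.ndarray, threshold: float = 1.01) -> float:
    """Hypothetical helper: emit a warning when split R-hat exceeds threshold."""
    r = split_rhat(draws)
    if r > threshold:
        warnings.warn(f"r-hat = {r:.3f} exceeds {threshold}")
    return r


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))       # four well-mixed chains
bad = good + np.arange(4)[:, None]      # each chain stuck at a different mean
print(warn_if_unconverged(good))
print(warn_if_unconverged(bad))         # triggers the warning
```

With a real trace you would pass the draws of one variable, e.g. from `idata.posterior`, reshaped to `(chains, draws)`.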
You can set it in `pm.sample` via the `nuts_sampler_kwargs` keyword, as in `pm.sample(nuts_sampler='blackjax', nuts_sampler_kwargs={'chain_method': 'vectorized'})`.