Where can I set chain_method='vectorized' in PyMC 5.15.0 when using BlackJAX?

I can find it in PyMC 5.10.3 but not in PyMC 5.15.0. Also, in PyMC 5.10.3 it does not automatically give a warning when rhat is larger than 1.

You can set it in pm.sample via the nuts_sampler_kwargs keyword, as in pm.sample(nuts_sampler='blackjax', nuts_sampler_kwargs={'chain_method':'vectorized'})
