PyMC 6.0: nutpie vs NumPyro for a GP-heavy model

Hey all!

With PyMC 6.0 making nutpie the default NUTS backend, now with native JAX integration, what's the best port of call for sampling GP-heavy likelihoods? Until now, the usual advice has been to use NumPyro as the sampling backend. Is that still the case?

The normalizing flow integration looks cool, so it would be good to try it out if there's negligible time lost moving from NumPyro to nutpie 🙂

Cheers!
Oli

Nutpie can use either Numba or JAX for computation via the nuts argument, so you can try both and see which works better for you:

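import pymc as pm

with model:  # your existing pm.Model; pick one of the calls below
    idata = pm.sample(nuts={"backend": "numba"})  # Numba mode (the default)
    idata = pm.sample(nuts={"backend": "jax"})  # JAX mode
    idata = pm.sample(nuts={"backend": "jax", "gradient_backend": "jax"})  # JAX mode, JAX autodiff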

The gradient backend only controls autodiff: the default is PyTensor, but you can also have JAX do it. Either way, the whole logp_dlogp function (log density plus its gradient) ends up as JAX. We should be faster in most cases, but if you find one where we aren't, I'd like to hear about it.

Keep in mind that much of nutpie's advantage comes from more efficient tuning, so you should be able to get away with far fewer tuning steps than with NumPyro, even if the per-step evaluation speed is about the same.
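A back-of-the-envelope way to see why fewer tuning steps matter even at equal per-step speed: if every iteration costs roughly the same, total work per chain scales with tune + draws (the pm.sample arguments of those names). This sketch assumes an equal cost per iteration, which is a simplification (NUTS trajectory lengths vary, especially early in tuning), and the step counts are hypothetical placeholders, not benchmark results.

```python
def relative_cost(tune_a: int, draws_a: int, tune_b: int, draws_b: int) -> float:
    """Ratio of run B's cost to run A's cost per chain.

    Assumes every iteration (tuning or sampling) costs the same amount
    of work, so total cost is proportional to tune + draws.
    """
    return (tune_b + draws_b) / (tune_a + draws_a)


# Hypothetical settings: run A tunes for 1000 steps, run B for 300,
# and both keep 1000 posterior draws per chain.
ratio = relative_cost(tune_a=1000, draws_a=1000, tune_b=300, draws_b=1000)
print(f"B costs {ratio:.0%} of A")  # prints "B costs 65% of A"
```

Under these assumptions the saving comes entirely from the tuning phase, so it grows as tuning takes up a larger share of the run.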


Thanks Jesse!

Am I right to assume that my current use of pm.sample(nuts_sampler="numpyro") in PyMC 5 is equivalent to pm.sample(nuts={'backend':'jax', 'gradient_backend':'jax'}) in PyMC 6?

In PyMC 6, pm.sample(nuts={'backend':'jax', 'gradient_backend':'jax'}) samples with nutpie if it's installed (the recommended setup); otherwise it falls back to PyMC's Python NUTS sampler.

If you want to stick with NumPyro, just pass nuts_sampler="numpyro" as before, although we recommend nutpie even for JAX/GPU.

More info: PyMC 6.0 & PyTensor 3.0: ecosystem updates (PyMC project website)

Perfect and clear, thanks everyone! I'll get to work 💪

Those backend arguments aren't what triggers nutpie, though, right? We always default to nutpie, but in Numba mode. Passing backend='jax' gives you JAX mode, which is always sufficient; the extra gradient_backend='jax' argument is entirely optional.

Correct, although incompatible arguments can cause PyMC to fall back from nutpie. To either get nutpie explicitly or get an error if it can't be used, pass nuts_sampler="nutpie".
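To summarise the selection behaviour described in the last few replies, here is a toy model in plain Python. It is not PyMC's code: the function which_sampler and its flags nutpie_installed and nutpie_compatible are hypothetical, and it only encodes what was said in this thread (nutpie by default in Numba mode, backend choosing the compile mode, incompatible options falling back to PyMC's own NUTS, and nuts_sampler="nutpie" either using nutpie or failing).

```python
from typing import Optional


def which_sampler(
    nuts_sampler: Optional[str] = None,
    nuts: Optional[dict] = None,
    nutpie_installed: bool = True,   # hypothetical flag: is nutpie importable?
    nutpie_compatible: bool = True,  # hypothetical flag: do the options suit nutpie?
) -> str:
    """Toy model of sampler selection as described in this thread.

    Not PyMC's implementation. gradient_backend is ignored on purpose:
    it only changes how the gradient is derived, not which sampler runs.
    """
    nuts = nuts or {}
    backend = nuts.get("backend", "numba")  # nutpie defaults to Numba mode

    if nuts_sampler == "nutpie":
        # Explicit request: use nutpie or fail loudly, never fall back.
        if not (nutpie_installed and nutpie_compatible):
            raise RuntimeError("nutpie was requested but cannot be used")
        return f"nutpie ({backend} mode)"

    if nuts_sampler is not None:
        # Any other explicit choice, e.g. "numpyro" as in PyMC 5.
        return nuts_sampler

    # Default path: nutpie when possible, else PyMC's Python NUTS ("pymc").
    if nutpie_installed and nutpie_compatible:
        return f"nutpie ({backend} mode)"
    return "pymc"


if __name__ == "__main__":
    cases = (
        {},
        {"nuts": {"backend": "jax"}},
        {"nuts": {"backend": "jax", "gradient_backend": "jax"}},
        {"nuts_sampler": "numpyro"},
        {"nutpie_installed": False},
    )
    for kwargs in cases:
        print(kwargs, "->", which_sampler(**kwargs))
```

For example, which_sampler(nuts={"backend": "jax"}) returns "nutpie (jax mode)", while which_sampler(nuts_sampler="numpyro") returns "numpyro", matching the "stick with NumPyro as before" advice.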