New sampler library: nutpie

I would like to announce a new sampler library that can sample from pymc and stan models using NUTS: nutpie.

At least in my benchmarks, it is currently the fastest CPU-based sampler implementing NUTS.

It compiles logp gradient functions for pymc models using the numba backend and calls those functions from a Rust implementation of NUTS, without Python overhead. In all the models I tried it outperformed stan and pymc (and probably also numpyro, but I didn't test that as thoroughly), with some caveats mentioned below.

It uses a different method for diagonal mass matrix adaptation, which performed as well as or sometimes significantly better than the methods in stan or pymc on my models: it often lowers the required number of gradient evaluations substantially. I'd be curious to see how it does on other people's models.

Sampler statistics are available as well, including some (the current mass matrix or the precise locations of divergences) that were not available in pymc before.

It is easy to install from conda-forge (conda install -c conda-forge nutpie), and easy to use:

import nutpie

pymc_model = create_some_model()
compiled_model = nutpie.compile_pymc_model(pymc_model)
trace = nutpie.sample(compiled_model, chains=10)
# trace is now an arviz.InferenceData object
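For anyone who wants something copy-pasteable, here is a minimal end-to-end sketch with a toy model (the model itself is just an illustration and not part of the original example):

import pymc as pm
import nutpie

# a tiny toy model, purely for illustration
with pm.Model() as pymc_model:
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=[0.1, -0.3, 0.4])

compiled_model = nutpie.compile_pymc_model(pymc_model)
trace = nutpie.sample(compiled_model, chains=4)

# the returned arviz.InferenceData carries the per-draw sampler
# statistics mentioned above in its sample_stats group
print(list(trace.sample_stats.data_vars))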

Some current caveats that will hopefully be resolved soon:

  • nutpie uses the numba backend of aesara, and if a pymc model uses an Op that does not have a native numba implementation and falls back to object mode, this will kill performance completely. Also, the numba backend isn't as well tested as the old backend, so there might still be bugs hiding somewhere.
  • It can also sample from models implemented in stan, but to do so it requires a patched version of httpstan (see the nutpie readme for details).
  • A numba issue significantly hurts performance (often by more than a factor of 2). This issue is actively being worked on, and I hope there will be a fix relatively soon. (My benchmarks against the other samplers were run with a patch that avoids this issue.)
  • Compile times for models can sometimes be quite long, but this part isn't well optimized yet, so I hope we can improve it significantly.

Adrian, I remember you saying you were going to implement NUTS in Rust (probably a few years back already), and you finally did it! Huge congrats!



Got nutpie working with a Bambi hierarchical model. Faster than even numpyro and blackjax in the model I tested. This is great work. And it is not optimized yet…? Wow.


This is amazing. Thank you, it speeds up my model drastically on CPU.

Is there an explanation of what the interval plots are?

Is there any way to hide the params#_interval_# variables from the trace plots?
Edit: you can deselect them with
var_names = list(model.named_vars.keys())
and then pass var_names to the az plotting function.
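A minimal sketch of that workaround, assuming arviz's az.plot_trace is the plotting call in question:

import arviz as az

# keep only the variables explicitly named in the model, which drops the
# automatically added *_interval__ transforms from the trace plot
var_names = list(model.named_vars.keys())
az.plot_trace(trace, var_names=var_names)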

I got Nutpie to work with Bambi + PyMC 5.1.2, via the new nuts_sampler argument. If anyone is interested in cloning my environment, I have attached the yml file. Just rename the txt extension to yml and use it to create your conda environment. Then use nuts_sampler='nutpie' in your Bambi .fit() or PyMC .sample() call and it should work.

bambi_nutpie.txt (8.5 KB)
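For reference, here is a minimal sketch of that call (the formula and data below are placeholders, not taken from the attached environment):

import numpy as np
import pandas as pd
import bambi as bmb

# toy data, purely for illustration
rng = np.random.default_rng(0)
data = pd.DataFrame({"x": rng.normal(size=100)})
data["y"] = 2 * data["x"] + rng.normal(size=100)

# Bambi forwards extra keyword arguments on to pm.sample(),
# so nuts_sampler="nutpie" selects the nutpie backend
model = bmb.Model("y ~ x", data)
idata = model.fit(nuts_sampler="nutpie")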


New to nutpie. A couple of questions:

  1. How do I exclude the default logodds trace during the inference step?
  2. Is there some nutpie documentation out there?

I tried your fix, and it did not work in a Windows environment, even though Nutpie does work on Windows.

Have you been successfully using Nutpie in Bambi?

Any help would be appreciated!