I would like to announce a new sampler library that can sample from pymc and stan models using NUTS: nutpie.
At least in my benchmarks, it is currently the fastest CPU-based sampler implementing NUTS.
It uses the numba backend to compile logp gradient functions for pymc models, and calls those functions from a Rust implementation of NUTS without any Python overhead. In all models I tried, it outperformed stan and pymc (and probably numpyro too, though I tested that less thoroughly), with some caveats mentioned below.
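To make "logp gradient function" a bit more concrete: NUTS evaluates the log density and its gradient once per leapfrog step, and a single draw can take many steps, so per-call overhead adds up quickly. The plain-Python sketch here only illustrates that calling pattern. The toy model `logp_and_grad` (independent standard normals) and the helpers `leapfrog` and `energy` are hypothetical names, and none of this is nutpie's numba/Rust code.

```python
def logp_and_grad(x):
    """Hypothetical toy model: independent standard normals (logp up to a constant)."""
    logp = -0.5 * sum(xi * xi for xi in x)
    grad = [-xi for xi in x]
    return logp, grad


def leapfrog(x, p, grad, step_size, inv_mass_diag, logp_grad_fn):
    """One leapfrog step with a diagonal inverse mass matrix.

    `grad` is the gradient at `x`, carried over from the previous step,
    so each step calls `logp_grad_fn` exactly once.
    """
    p_half = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad)]
    x_new = [xi + step_size * mi * pi for xi, mi, pi in zip(x, inv_mass_diag, p_half)]
    logp_new, grad_new = logp_grad_fn(x_new)  # the only gradient evaluation per step
    p_new = [pi + 0.5 * step_size * gi for pi, gi in zip(p_half, grad_new)]
    return x_new, p_new, grad_new, logp_new


def energy(logp, p, inv_mass_diag):
    """Hamiltonian: potential energy (-logp) plus kinetic energy."""
    return -logp + 0.5 * sum(mi * pi * pi for mi, pi in zip(inv_mass_diag, p))


# Usage: a short trajectory of 10 steps; NUTS builds its trees from such steps.
x0, p0 = [1.0, -0.5], [0.3, 0.8]
inv_mass = [1.0, 1.0]
logp0, grad0 = logp_and_grad(x0)
x, p, grad, logp = x0, p0, grad0, logp0
for _ in range(10):
    x, p, grad, logp = leapfrog(x, p, grad, 0.1, inv_mass, logp_and_grad)
print(energy(logp0, p0, inv_mass), energy(logp, p, inv_mass))
```

The energy stays close to its starting value, which is what makes the Metropolis correction in NUTS accept most proposals. The sampler's wall-clock time is dominated by the repeated `logp_grad_fn` calls, which is the part nutpie compiles.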
It uses a different method for diagonal mass matrix adaptation, which on my models performed as well as, and sometimes significantly better than, the methods in stan or pymc: it often lowers the required number of gradient evaluations significantly. I'd be curious to see how it does on other people's models.
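The post doesn't spell out how nutpie's adaptation works, so what follows is only an orientation sketch with hypothetical helper names. It contrasts the usual Stan-style diagonal estimate (per-dimension variance of the draws) with one gradient-informed alternative, `sqrt(var(draws) / var(gradients))`. Treat that second estimator as an illustrative assumption, not as nutpie's implementation. For simplicity, it uses exact independent draws instead of MCMC warmup draws.

```python
import math
import random
import statistics


def draw_variance_estimate(draws):
    """Stan-style diagonal estimate: per-dimension sample variance of the draws."""
    return [statistics.variance(column) for column in zip(*draws)]


def gradient_informed_estimate(draws, grads):
    """Illustrative alternative (assumption): sqrt(var(draws) / var(grads)) per dimension."""
    return [
        math.sqrt(statistics.variance(xs) / statistics.variance(gs))
        for xs, gs in zip(zip(*draws), zip(*grads))
    ]


# Toy target: independent normals with standard deviations `sigma`, so
# grad_i log p(x) = -x_i / sigma_i**2 and the ideal diagonal of the
# inverse mass matrix (the posterior variance) is sigma_i**2.
sigma = [0.1, 1.0, 10.0]
rng = random.Random(0)
draws = [[rng.gauss(0.0, s) for s in sigma] for _ in range(5000)]
grads = [[-xi / s**2 for xi, s in zip(d, sigma)] for d in draws]

draw_est = draw_variance_estimate(draws)
grad_est = gradient_informed_estimate(draws, grads)
print("target:           ", [s**2 for s in sigma])
print("draw variance:    ", draw_est)
print("gradient-informed:", grad_est)
```

On this independent-normal toy target the gradients are linear in the draws, so the gradient-informed estimate recovers the target variances up to floating-point rounding, while the draw variance carries sampling noise. On non-Gaussian posteriors the two estimators generally differ.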
Sampler statistics are available as well, including some that were not available in pymc before, such as the current mass matrix and the precise locations of divergences.
It is easy to install from conda-forge (`conda install -c conda-forge nutpie`) and easy to use:
```python
import nutpie

pymc_model = create_some_model()
compiled_model = nutpie.compile_pymc_model(pymc_model)
trace = nutpie.sample(compiled_model, chains=10)
# trace is now an arviz.InferenceData object
```
Some current caveats that will hopefully be resolved soon:
- nutpie uses the numba backend of aesara. If a pymc model uses an Op that has no native numba implementation and falls back to object mode, performance suffers drastically. Also, the numba backend isn't as thoroughly tested as the older default backend, so there may still be bugs hiding somewhere.
- It can also sample from models implemented in stan, but to do so it requires a patched version of httpstan (see the nutpie readme for details).
- A numba issue currently hurts its performance significantly (often by more than a factor of 2). It is actively being worked on, and I hope a fix will land relatively soon. (My benchmarks against other samplers used a patch that avoids this issue.)
- Model compile times can sometimes be quite long, but this part isn't well optimized yet, so I hope we can improve it significantly.