Isolate NUTS into a new library

The main workhorse of pymc3 and many other modern Bayesian libraries is the No-U-Turn Sampler (NUTS). Internally, pymc3 builds the probabilistic model (everything in the with pm.Model() ... context) in theano by compiling a callable: you pass it a 1d array (all the unknown/free parameters) and get back a scalar log_prob and its gradient (a 1d array). NUTS (or any other sampler) then takes that callable plus an initial point and does the rest in numpy + multiprocessing.
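To make the shape of that compiled callable concrete, here is a minimal sketch in plain numpy. The function name and the choice of a standard normal target are illustrative assumptions, not pymc3 internals; the point is only the signature: 1d array in, (scalar log-prob, 1d gradient) out.

```python
import numpy as np

def logp_and_grad(x):
    """Sketch of the kind of callable pymc3 compiles: takes a 1d array
    of free parameters, returns (scalar log-prob, gradient array).
    Target here: independent standard normals on each parameter."""
    logp = -0.5 * np.sum(x ** 2) - 0.5 * x.size * np.log(2 * np.pi)
    grad = -x  # gradient of -0.5 * x**2 with respect to x
    return logp, grad

logp, grad = logp_and_grad(np.zeros(3))
```

A sampler written against this signature never needs to know anything about the model-building framework that produced the callable.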

It would be great to isolate this part of the code into a standalone library; it would help us with testing and benchmarking.


Is it just for information, or is it a call for volunteers?

For context, this is to follow up discussion on Twitter:

Ooh ok, interesting! I’d be delighted to join the project if you think I’ve got the required skills and would provide useful help!

I’m here for this! I would definitely make use of something like this if it existed and I’m happy to put in legwork to make it happen.

@twiecki mentioned that we might be able to start with the version in the pymc4 codebase, but I think that there are a few open questions.

The main (big) question that I have: are there design decisions that would be good to update while doing a re-write like this? For example, we might want to reconsider the interface currently defined for tuning the parameters of the quadpotential objects (right now they’re updated every step).

There are also questions about where to draw the split: how much of the infrastructure would live in which library? And what would the interface look like? I imagine a class with something like logp, logp_and_grad, and deterministics methods, but I'm sure that the core devs will have a better sense. I'd want to make sure that this interface would be flexible enough to support (at least in principle) models defined in non-pymc frameworks.
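The class imagined above could be sketched as a structural interface. Everything here is hypothetical (the names ModelBackend and StandardNormalModel are made up for illustration); only the three method names come from the post.

```python
from typing import Dict, Protocol, Tuple
import numpy as np

class ModelBackend(Protocol):
    """Hypothetical interface sketch -- method names (logp,
    logp_and_grad, deterministics) follow the post's suggestion;
    this is not an agreed-upon API."""
    def logp(self, x: np.ndarray) -> float: ...
    def logp_and_grad(self, x: np.ndarray) -> Tuple[float, np.ndarray]: ...
    def deterministics(self, x: np.ndarray) -> Dict[str, np.ndarray]: ...

class StandardNormalModel:
    """Toy backend: standard normals (log-prob up to a constant),
    no deterministic quantities."""
    def logp(self, x: np.ndarray) -> float:
        return float(-0.5 * np.sum(x ** 2))
    def logp_and_grad(self, x: np.ndarray) -> Tuple[float, np.ndarray]:
        return self.logp(x), -x
    def deterministics(self, x: np.ndarray) -> Dict[str, np.ndarray]:
        return {}
```

Because Protocol uses structural typing, a wrapper around a non-pymc framework would satisfy the interface simply by providing these methods, with no inheritance required.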

No matter what, I put in my +1 for this idea!


These are all great questions. I think the easiest approach is not to touch pymc3 (i.e., pymc3 will not depend on it), but to isolate the HMC part (what @twiecki has done, plus multiprocessing) and add wrappers and helper functions around it. Visualization etc. would all depend on ArviZ. Essentially it would be a bare-bones, high-performance NUTS that takes a callable as input. In that case, I think the tensorflow_probability sample_chain API is pretty close to what I have in mind.
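A sample_chain-style driver of that shape might look like the sketch below. This is loosely modeled on the spirit of tfp.mcmc.sample_chain, not its actual signature, and random-walk Metropolis stands in for NUTS to keep the example short; the name sample_chain and all parameters are illustrative assumptions.

```python
import numpy as np

def sample_chain(logp_fn, init, num_results, step_size=0.5, seed=0):
    """Bare-bones sampling driver: takes only a log-prob callable and an
    initial point. Random-walk Metropolis here stands in for NUTS."""
    rng = np.random.default_rng(seed)
    x = np.asarray(init, dtype=float)
    logp = logp_fn(x)
    draws = np.empty((num_results, x.size))
    for i in range(num_results):
        # Propose a Gaussian perturbation and accept/reject it.
        proposal = x + step_size * rng.normal(size=x.size)
        logp_prop = logp_fn(proposal)
        if np.log(rng.uniform()) < logp_prop - logp:
            x, logp = proposal, logp_prop
        draws[i] = x
    return draws

# Usage: sample a 2d standard normal.
draws = sample_chain(lambda x: -0.5 * np.sum(x ** 2), np.zeros(2), 500)
```

Multiple chains would then just be several calls to this function farmed out over multiprocessing, one per worker.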


I think I’ve made some progress over here:

It’s still obviously a WIP, but I just cut the v0.1.0 release, and you can pip install littlemcmc to try it out. Hopefully there aren’t too many bugs! I’d also really appreciate any feedback anyone has 🙂

cc @junpenglao @dfm