New experimental sampling algorithm: Fisher HMC in nutpie for PyMC and Stan models

I’m excited to introduce an experimental implementation of Fisher-Adapted Hamiltonian Monte Carlo (HMC) using normalizing flows to replace the traditional mass matrix. This new method can increase sampling efficiency, especially for models with complex posteriors that struggle with standard HMC (like those with funnels or difficult parameterizations).
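
As a rough intuition (my own summary, so take it with a grain of salt): adapting a mass matrix is equivalent to running HMC on a linearly transformed parameter space, and this method replaces that fixed linear map with a learned normalizing flow, trained (as far as I understand, from the gradient evaluations the sampler collects anyway) so that the transformed posterior looks approximately standard normal:

$$
x = A z \;\;\text{(mass matrix: fixed linear map)} \qquad \longrightarrow \qquad x = f_\theta(z) \;\;\text{(normalizing flow: learned invertible map)}
$$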

It works with both PyMC and Stan models, and shouldn’t be too hard to try out. (famous last words…)

Current Limitations:

Computational Cost: The optimization process can be slow, but the additional computational cost matters less for models with expensive log-likelihood evaluations. (Don’t expect any model to finish in less than ~10 minutes.) I hope a lot of that cost will go away as the implementation gets better.

Parameter Tuning: You might have to tune some parameters manually, mainly nn_width (the width of the neural network layers) and num_layers (the number of normalizing flow layers). Larger values may improve expressiveness at the cost of a harder optimization problem.

Scaling: So far it doesn’t work well for models with more than about 1000 parameters. Hopefully that will change in the future as well.
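
If you’re not sure whether your model falls into that range, counting the free parameter values of a PyMC model is quick. A minimal sketch, where `model` is assumed to be the pm.Model you want to sample:

import numpy as np

# Rough sketch: count the free (unconstrained) parameter values of an
# existing PyMC model; `model` is assumed to be defined already.
n_params = sum(np.asarray(v).size for v in model.initial_point().values())
print(f"The model has {n_params} free parameter values")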

Installation and Setup:

To try out the new feature, follow the installation steps (using pixi):

git clone https://github.com/pymc-devs/nutpie
cd nutpie
git fetch origin pull/154/head:transform
git switch transform
pixi run develop
pixi shell

This should drop you into a shell with a (hopefully) correctly set up Python environment. You can also use pixi-kernel to connect it to Jupyter.
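
A quick way to check that the shell really picks up the development install is to see where Python imports nutpie from:

python -c "import nutpie; print(nutpie.__file__)"

The printed path should point into the pixi environment (or into the cloned repository, if it is an editable install).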

Usage:

import pymc as pm
import nutpie
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)

with pm.Model() as model:
    log_sd = pm.Normal("log_sd")
    pm.Normal("y", sigma=np.exp(log_sd))

compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)
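
The returned trace should be the usual ArviZ InferenceData object that nutpie produces, so the standard diagnostics apply. A minimal sketch, assuming arviz is installed in the environment:

import arviz as az

# ESS and R-hat for the funnel-like model above
print(az.summary(trace, var_names=["log_sd"]))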

Or with Stan models:

import nutpie
import jax
import os

os.environ["TBB_CXX_TYPE"] = "clang"
jax.config.update("jax_enable_x64", True)

code = """
parameters {
    real log_sigma;
    real x;
}
model {
    log_sigma ~ normal(0, 1);
    x ~ normal(0, exp(log_sigma));
}
"""


compiled = nutpie.compile_stan_model(code=code)

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)
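
To see whether the flow-based adaptation actually pays off for a given model, one simple check is to sample the same compiled model once more with the default mass-matrix adaptation and compare effective sample sizes. Just a sketch, assuming arviz is available:

import arviz as az

# Same model, default (mass matrix) adaptation for comparison
trace_default = nutpie.sample(
    compiled,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)

# A clearly higher ESS for log_sigma in the transform-adapted run suggests
# the flow is handling the funnel geometry better.
print(float(az.ess(trace)["log_sigma"]))
print(float(az.ess(trace_default)["log_sigma"]))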

GPU support helps a lot here. JAX should automatically pick up a CUDA device if one is available, which speeds up the optimization considerably, even if the logp function itself doesn’t evaluate particularly well on the GPU.
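
You can check which devices JAX found with:

import jax

# Should list a CUDA device (e.g. CudaDevice(id=0)) if a GPU was picked up,
# and a CPU device otherwise.
print(jax.devices())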

I’m eager to hear how this performs with different models! Please give it a try and let us know how it works for you!


@aseyboldt don’t you want to share it in the Stan discourse as well?

Thanks a lot for implementing and making this available @aseyboldt. Could you share a publication that explains the main ideas of the algorithm, for reference?

@ricardoV94 I’ll post on the Stan discourse soon.

@Benjamin I was afraid someone might ask…
The best I can currently offer is a WIP paper: https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf


Thanks a lot @aseyboldt for sharing the WIP. Very nice read already, looking forward to trying this out.

I linked to this thread from a new Stan Discourse topic.
