I’m excited to introduce an experimental implementation of Fisher-Adapted Hamiltonian Monte Carlo (HMC) using normalizing flows to replace the traditional mass matrix. This new method can increase sampling efficiency, especially for models with complex posteriors that struggle with standard HMC (like those with funnels or difficult parameterizations).
It works with both PyMC and Stan models, and shouldn’t be too hard to try out. (famous last words…)
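To illustrate why a nonlinear transformation can help where a constant mass matrix cannot, here is a minimal NumPy sketch built on the same funnel as the PyMC usage example further down. It is not nutpie's implementation: the transformation is written by hand, whereas the flow would have to learn something like it during tuning. The function names (`funnel_logp`, `forward`, `std_normal_logp`) are made up for this illustration.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def funnel_logp(log_sd, y):
    """Log density of the funnel: log_sd ~ N(0, 1), y ~ N(0, exp(log_sd))."""
    logp_log_sd = -0.5 * log_sd**2 - 0.5 * LOG_2PI
    logp_y = -0.5 * (y / np.exp(log_sd)) ** 2 - log_sd - 0.5 * LOG_2PI
    return logp_log_sd + logp_y


def std_normal_logp(z):
    """Log density of independent standard normals."""
    return np.sum(-0.5 * z**2 - 0.5 * LOG_2PI)


def forward(z):
    """Map z = (z0, z1) to funnel coordinates and return log|det J|."""
    log_sd = z[0]
    y = z[1] * np.exp(log_sd)
    # The Jacobian is triangular with diagonal (1, exp(log_sd)).
    log_det = log_sd
    return log_sd, y, log_det


z = np.array([-2.0, 0.7])
log_sd, y, log_det = forward(z)

# Density of the funnel expressed in z coordinates (change of variables).
pulled_back = funnel_logp(log_sd, y) + log_det
print(pulled_back, std_normal_logp(z))
```

In the `z` coordinates the funnel is exactly a standard normal, which is easy for HMC. No single linear rescaling achieves this, because the scale of `y` depends on `log_sd`.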
Current Limitations:
Computational Cost: Training the normalizing flow can be slow, although the extra cost matters less when the log-likelihood itself is expensive to evaluate. (Don’t expect any model to finish in less than ~10 minutes.) I hope a lot of that cost will go away as the implementation gets better.
Parameter Tuning: You might have to tune some parameters manually, mainly nn_width (the width of the neural network layers) and num_layers (the number of normalizing flow layers). Larger values may improve expressiveness at the cost of optimization difficulty.
Scaling: So far it doesn’t work great with models with more than about 1000 parameters. Hopefully that will change in the future as well.
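To give a feeling for what nn_width and num_layers control, here is a rough NumPy sketch of a stack of affine coupling layers (RealNVP-style). This is an assumption for illustration only: the flow architecture nutpie actually uses may differ, and all function names below are hypothetical. Each layer contains a small network with one hidden layer of `nn_width` units, and `num_layers` such layers are chained.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_layer(n_params, nn_width):
    """Random weights for the small network inside one coupling layer."""
    half = n_params // 2
    n_out = 2 * (n_params - half)  # one log-scale and one shift per output
    return {
        "w1": rng.normal(scale=0.1, size=(half, nn_width)),
        "b1": np.zeros(nn_width),
        "w2": rng.normal(scale=0.1, size=(nn_width, n_out)),
        "b2": np.zeros(n_out),
    }


def scale_and_shift(layer, x_a):
    """Network with a single hidden layer of nn_width tanh units."""
    hidden = np.tanh(x_a @ layer["w1"] + layer["b1"])
    return np.split(hidden @ layer["w2"] + layer["b2"], 2)


def coupling_forward(layer, x):
    """Keep the first half, rescale and shift the second half."""
    half = x.shape[0] // 2
    x_a, x_b = x[:half], x[half:]
    log_scale, shift = scale_and_shift(layer, x_a)
    y_b = x_b * np.exp(log_scale) + shift
    return np.concatenate([x_a, y_b]), np.sum(log_scale)


def coupling_inverse(layer, y):
    """Exact inverse of coupling_forward."""
    half = y.shape[0] // 2
    y_a, y_b = y[:half], y[half:]
    log_scale, shift = scale_and_shift(layer, y_a)
    return np.concatenate([y_a, (y_b - shift) * np.exp(-log_scale)])


def flow_forward(layers, x):
    total_log_det = 0.0
    for layer in layers:
        x, log_det = coupling_forward(layer, x)
        total_log_det += log_det
        x = x[::-1]  # reverse so the untouched half is transformed next
    return x, total_log_det


def flow_inverse(layers, y):
    for layer in reversed(layers):
        y = coupling_inverse(layer, y[::-1])
    return y


n_params, nn_width, num_layers = 6, 3, 8
layers = [init_layer(n_params, nn_width) for _ in range(num_layers)]
x = rng.normal(size=n_params)
y, log_det = flow_forward(layers, x)
x_roundtrip = flow_inverse(layers, y)
```

The number of weights grows with both `nn_width` and `num_layers`, which is one way to see why larger values make the flow more expressive but also harder to optimize.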
Installation and Setup:
To try out the new feature, follow the installation steps (using pixi):
git clone https://github.com/pymc-devs/nutpie
cd nutpie
git fetch origin pull/154/head:transform
git switch transform
pixi run develop
pixi shell
This will give you a shell with a (hopefully) correctly set up Python environment. You can also use pixi-kernel to connect it to Jupyter.
Usage:
import pymc as pm
import nutpie
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)
with pm.Model() as model:
    log_sd = pm.Normal("log_sd")
    pm.Normal("y", sigma=np.exp(log_sd))
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)
trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)
Or with Stan models:
import pymc as pm
import nutpie
import numpy as np
import jax
import os
os.environ["TBB_CXX_TYPE"] = "clang"
jax.config.update("jax_enable_x64", True)
code = """
parameters {
    real log_sigma;
    real x;
}
model {
    log_sigma ~ normal(0, 1);
    x ~ normal(0, exp(log_sigma));
}
"""
compiled = nutpie.compile_stan_model(code=code)
compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)
trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)
GPU support is very helpful. JAX should automatically pick up a CUDA device if available, and that should speed up the optimization a lot, even if the logp function evaluation doesn’t work too great on the GPU.
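If you are unsure whether JAX actually found your GPU, a quick check like the following can help. It only uses the standard `jax.default_backend()` and `jax.devices()` calls and says nothing about nutpie itself. The exact device names printed depend on your installation.

```python
import jax

# Name of the backend JAX selected by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()

# All devices visible to that backend.
devices = jax.devices()

print(f"JAX default backend: {backend}")
for device in devices:
    print(f"  {device.platform}: {device}")
```

If this reports only a CPU device on a machine with a CUDA GPU, the installed jaxlib build most likely lacks CUDA support.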
I’m eager to hear how this performs with different models! Please give it a try and let us know how it works for you!