Series of posts on implementing Hamiltonian Monte Carlo

I am working on a ~weekly series on modern implementations of gradient-based samplers. Three parts are out so far, and I am happy to discuss them here, or in the issues of the GitHub repo that accompanies the articles (minimc).

This series is strongly influenced by PyMC3’s implementation, and I am using it as a testbed of ideas for PyMC4 and improvements to PyMC3.

Please let me know your thoughts!

Part I: Exercises in automatic differentiation using autograd and jax. Gives a background in automatic differentiation.
Part II: Hamiltonian Monte Carlo from scratch. Gives the basic implementation of HMC in around 20 lines of Python.
Part III: Step Size Adaptation in Hamiltonian Monte Carlo. Presents improvements to the basic algorithm, giving a ~10x speedup.


I think these posts are great. From a beginner's perspective (like mine), they are quite informative for two reasons: the graphs are well made and used effectively (and we should all remember how much more efficient a single good graph can be at communicating complex data than a million words), and your explanations are concise.

I would like to comment further and more deeply on these posts, but alas I am still an MCMC beginner; even so, these posts were useful for conceptualizing the steps of advanced MC schemes.

The graphs also reminded me of these animations: https://chi-feng.github.io/mcmc-demo/.


New post here on higher order integrators: what they are, and why we don't use them.

As a historical note, they exist in the PyMC3 codebase, implemented in Theano, but we switched the integrators to NumPy soon after, and Bob Carpenter of Stan said the higher order integrators never really helped in practice. I think no one has the heart to get rid of them or to document how and when to use them.


I’m trying to understand how HMC works and found the implementation from scratch really useful.

What it’s missing, however, is the implementation of the negative_log_prob function that’s called in the main hmc function.

for p0 in momentum.rvs(size=size):
    # Integrate over our path to get a new position and momentum
    q_new, p_new = leapfrog(
        samples[-1],
        p0,
        dVdq,
        path_len=path_len,
        step_size=step_size,
    )

    # Check Metropolis acceptance criterion
    start_log_p = negative_log_prob(samples[-1]) - np.sum(momentum.logpdf(p0))
    new_log_p = negative_log_prob(q_new) - np.sum(momentum.logpdf(p_new))
    if np.log(np.random.rand()) < start_log_p - new_log_p:
        samples.append(q_new)
    else:
        samples.append(np.copy(samples[-1]))

I understand (I think) that this function should be calculating the posterior log probability that gets updated as sampling progresses through the chains.

However, I can't quite conceptualize what this function would look like.
A skeleton code or a link would be really helpful. Thank you!


Hey! Glad you’re finding it useful! Two changes I’d make if I wrote this today:

  1. The negative_log_prob could maybe be better called negative_log_density.
  2. I no longer remember why I negated it (the negation makes some of the math harder to follow).

Anyways! For, say, a normal distribution with mean 2 and std 1, this could be just negative_log_prob = lambda x: 0.5 * (x - 2)**2 (the normalizing constant can be dropped).
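A quick sanity check (using scipy here, my addition, just for comparison): with the 1/2 factor from the Gaussian exponent, this unnormalized version differs from the exact negative log pdf of N(2, 1) only by an additive constant, which MCMC never sees.

```python
import numpy as np
from scipy.stats import norm

# unnormalized negative log density of N(2, 1)
negative_log_prob = lambda x: 0.5 * (x - 2) ** 2

xs = np.linspace(-3.0, 7.0, 5)
# difference from the exact negative log pdf at several points
offsets = negative_log_prob(xs) - (-norm.logpdf(xs, loc=2, scale=1))
# every offset is the same: -0.5 * log(2 * pi), the dropped normalizing constant
```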

More generally, you can use jax.scipy.stats which has some built-in distributions. For a normal with mean 1 and standard deviation 10,

import jax.scipy.stats as jst

negative_log_prob = lambda x: -jst.norm.logpdf(x, 1., 10.)  # N(x | 1, 10)

For even more distributions, libraries like distrax, numpyro, or tensorflow_probability have extensive collections, and ways of constructing complicated densities.
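One more reason to write the density with jax in the first place: HMC also needs the gradient of the negative log density (the dVdq passed to leapfrog in the snippet above), and jax.grad provides it automatically. A minimal sketch:

```python
import jax
import jax.scipy.stats as jst

negative_log_prob = lambda x: -jst.norm.logpdf(x, 1., 10.)  # N(x | 1, 10)

# gradient of the negative log density, as consumed by the leapfrog integrator
dVdq = jax.grad(negative_log_prob)

# for N(1, 10), the gradient is (x - 1) / 10**2, so approximately 0.2 at x = 21
grad_value = dVdq(21.0)
```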

Thanks for the quick response.

But shouldn’t this be the posterior probability? I see no multiplication of the prior and the likelihood in what you cited.

My parallel is the simplified code I wrote to explain to myself how the Metropolis algo works.

It’s likely I’m missing something conceptually.

Ah yeah, sorry: so MCMC is an algorithm for sampling from an (unnormalized) probability density, specified through its log. We of course have efficient ways of generating samples from a normal or Poisson or whatever distribution, but MCMC can still be used to sample from these.

More commonly, as you point out, MCMC is used for distributions where we do not have efficient samplers, especially in Bayesian inference, where p(theta | data) is typically known only up to a normalizing constant. You could do that here if you wanted! For example, here is 1-dimensional linear regression:

import numpy as np
import jax.numpy as jnp
import jax.scipy.stats as jst

# something like y = 2x + noise
x = np.arange(10)
y = 2 * x + np.random.randn(10)

# fit a model y = mu * x + N(0, 1), with mu ~ N(0, 1)
def neg_log_prob(mu):
    log_prior = jst.norm.logpdf(mu, 0., 1.)
    log_likelihood = jnp.sum(jst.norm.logpdf(y, mu * x, 1.))
    return -(log_prior + log_likelihood)
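To feed this into the hmc function from the post, you'd also need its gradient (the dVdq argument), which jax.grad gives you. A sketch; the fixed seed is my addition, just to make the numbers reproducible:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

np.random.seed(0)  # fixed seed (my addition) so the example is reproducible
x = np.arange(10)
y = 2 * x + np.random.randn(10)

def neg_log_prob(mu):
    log_prior = jst.norm.logpdf(mu, 0., 1.)
    log_likelihood = jnp.sum(jst.norm.logpdf(y, mu * x, 1.))
    return -(log_prior + log_likelihood)

# gradient of the potential energy: the dVdq that leapfrog consumes
dVdq = jax.grad(neg_log_prob)

# at mu = 0 the gradient is strongly negative, pushing mu up
# toward the posterior mode near 2 (note jax.grad needs a float input)
grad_at_zero = dVdq(0.0)
```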

Thanks for the detailed example, it's extremely helpful. I'm following everything except why the mean in the likelihood definition is mu * x. If I'm understanding the logic, shouldn't it be 2 * (x + mu)? I'm probably overlooking something simple.

I think this is right. Our model is

\begin{align} \mu &\sim \mathcal{N}(0, 1) \\ y &\sim \mathcal{N}(\mu x, 1) \end{align}

We’d expect that after sampling this model, conditioned on the given data, the posterior distribution of \mu is centered around 2.
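In fact, for this particular model (normal prior, normal likelihood, known unit noise) the posterior is conjugate and available in closed form, so we can check that claim directly without sampling. A sketch, with the data regenerated under a fixed seed (my addition):

```python
import numpy as np

np.random.seed(0)  # fixed seed (my addition) so the numbers are reproducible
x = np.arange(10)
y = 2 * x + np.random.randn(10)

# with mu ~ N(0, 1) and y ~ N(mu * x, 1), the posterior for mu is
# N(post_mean, 1 / post_precision)
post_precision = 1.0 + np.sum(x * x)
post_mean = np.sum(x * y) / post_precision
post_std = post_precision ** -0.5
# post_mean comes out close to 2, with a small posterior standard deviation
```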


And PyMC… Probability — PyMC 5.10.0 documentation :wink:

I think you'd need to jaxify the log probability (and negate it!) first, but definitely! I also shied away from PyMC because the blog post (I think!) doesn't cover

  • multiple dimensions
  • change of variables

both of which you might have to deal with if you use a real PPL!

This might just be a me problem, but I am getting a "connection is not private" error that says "Attackers might be trying to steal your information from colindcarroll.com (for example, passwords, messages, or credit cards)." Are these posts on HMC still active?

EDIT: Just to note, I have tried on several computers/browsers/wifi networks, so it's not just one device or browser for me.


CC @colcarroll


Oops. Need to renew my certificate! Will get it tonight - thanks for letting me know!

[update 14 hours later: it is done.]


Thanks!
