Series of posts on implementing Hamiltonian Monte Carlo

Hey! Glad you’re finding it useful! Two changes I’d make if I wrote this today:

  1. neg_log_prob would probably be better named neg_log_density, since for continuous variables it is a density rather than a probability.
  2. I no longer remember why I negated it, and the negation makes some of the math harder to follow.
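For context on the sign: HMC is usually written in terms of a potential energy U(q) = -log p(q) (up to a constant), so a negated log density drops straight into the Hamiltonian H(q, p) = U(q) + p·p/2. Here is a minimal sketch of one leapfrog step written that way. It assumes a unit mass matrix, and leapfrog_step and hamiltonian are illustrative names, not code from the posts.

```python
import jax
import jax.numpy as jnp

def leapfrog_step(neg_log_prob, q, p, step_size):
    """One leapfrog step for H(q, p) = U(q) + p.p / 2 with U = neg_log_prob.

    Assumes a unit (identity) mass matrix.
    """
    grad_u = jax.grad(neg_log_prob)
    p = p - 0.5 * step_size * grad_u(q)  # half step in momentum
    q = q + step_size * p                # full step in position
    p = p - 0.5 * step_size * grad_u(q)  # half step in momentum
    return q, p

def hamiltonian(neg_log_prob, q, p):
    # Potential energy (negative log density) plus kinetic energy
    return neg_log_prob(q) + 0.5 * jnp.sum(p ** 2)

# Standard normal target: U(q) = q**2 / 2 (additive constant dropped)
neg_log_prob = lambda q: 0.5 * jnp.sum(q ** 2)
q0, p0 = jnp.array([1.0]), jnp.array([0.5])
q1, p1 = leapfrog_step(neg_log_prob, q0, p0, step_size=0.1)

H0 = hamiltonian(neg_log_prob, q0, p0)
H1 = hamiltonian(neg_log_prob, q1, p1)
print(H0, H1)  # nearly equal: leapfrog approximately conserves energy
```

The acceptance step in HMC only looks at the change in H, which is why working with the negated log density (an "energy") is convenient even if it makes the probability side of the math slightly less direct.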

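Anyways! For, say, a normal distribution with mean 2 and std 1, this could be just negative_log_prob = lambda x: 0.5 * (x - 2)**2, dropping the additive constant 0.5 * log(2π), which HMC never needs because it only uses gradients and differences in energy.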
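A quick way to sanity-check a hand-written negative log density like this is to compare it against jax.scipy.stats.norm.logpdf. The two should differ only by a constant (here 0.5 * log(2π)), and their gradients, which are what HMC actually uses, should agree:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

# Hand-written version: mean 2, std 1, additive constant dropped
negative_log_prob = lambda x: 0.5 * (x - 2.0) ** 2

# Exact negative log density from jax.scipy.stats
exact = lambda x: -jst.norm.logpdf(x, 2.0, 1.0)

xs = jnp.linspace(-3.0, 7.0, 11)
offsets = exact(xs) - negative_log_prob(xs)  # each should be 0.5 * log(2 * pi)
print(offsets)

# The constant offset vanishes under differentiation, so gradients match
grad_mine = jax.vmap(jax.grad(negative_log_prob))(xs)
grad_exact = jax.vmap(jax.grad(exact))(xs)
print(jnp.max(jnp.abs(grad_mine - grad_exact)))
```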

More generally, you can use jax.scipy.stats, which has some built-in distributions. For a normal with mean 1 and standard deviation 10:

import jax.scipy.stats as jst

negative_log_prob = lambda x: -jst.norm.logpdf(x, 1., 10.)  # -log N(x | mean=1, std=10)
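Since HMC needs the gradient of this function, the same lambda can be passed straight to jax.grad. For a normal with mean μ and standard deviation σ, the gradient of the negative log density is (x - μ)/σ², so here it should be (x - 1)/100. For a vector-valued x with independent coordinates, the sum is one way (an illustrative choice) to get the scalar output jax.grad requires:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

negative_log_prob = lambda x: -jst.norm.logpdf(x, 1., 10.)  # mean 1, std 10
grad_neg_log_prob = jax.grad(negative_log_prob)
g = grad_neg_log_prob(11.0)  # analytically (11 - 1) / 10**2 = 0.1

# Independent coordinates: sum the per-coordinate terms into a scalar
negative_log_prob_vec = lambda x: -jnp.sum(jst.norm.logpdf(x, 1., 10.))
g_vec = jax.grad(negative_log_prob_vec)(jnp.array([11.0, -9.0]))
print(g, g_vec)
```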

For even more distributions, libraries like distrax, numpyro, or tensorflow_probability have extensive collections of distributions, along with tools for constructing more complicated densities.
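Simple compound densities can also be built by hand from jax.scipy.stats pieces without an extra dependency. As an illustrative sketch (the mixture weights, means, and scales below are arbitrary choices, not from the posts), the negative log density of a two-component Gaussian mixture uses logsumexp to combine the weighted component log densities in a numerically stable way:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst
from jax.scipy.special import logsumexp

# Arbitrary example mixture: 30% N(-2, 0.5), 70% N(3, 1)
weights = jnp.array([0.3, 0.7])
means = jnp.array([-2.0, 3.0])
stds = jnp.array([0.5, 1.0])

def negative_log_prob(x):
    # log p(x) = log sum_k w_k N(x | mu_k, sigma_k), computed stably
    component_logpdfs = jst.norm.logpdf(x, means, stds)  # shape (2,)
    return -logsumexp(component_logpdfs + jnp.log(weights))

grad_fn = jax.grad(negative_log_prob)
print(negative_log_prob(0.0), grad_fn(0.0))
```

Because the result is an ordinary scalar JAX function, it can be handed to an HMC sampler exactly like the single-normal examples.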