I am working on a ~weekly series on modern implementations of gradient based samplers. Three parts are out so far, and I am happy to discuss them here, or in the issues of the github repo that accompanies the articles (minimc).
This series is strongly influenced by PyMC3’s implementation, and I am using it as a testbed of ideas for PyMC4 and improvements to PyMC3.
I think these posts are great. From a beginner's perspective (like mine), they are quite informative for two reasons: the graphs are well made and used effectively (and we should all remember how much more efficiently a single good graph can communicate complex data than a million words), and your explanations are concise.
I would like to comment further and more deeply on these posts, but alas I am still an MCMC beginner. Even so, these posts were useful for conceptualizing and working through the steps of advanced MC schemes.
New post here on higher-order integrators: what they are, and why we don't use them.
As a historical note, they exist in the PyMC3 codebase, implemented in Theano, but we switched the integrators to NumPy soon after, and Bob Carpenter of Stan said the higher-order integrators never really helped in practice. I think no one has the heart to get rid of them, or to document how and when to use them.
I’m trying to understand how HMC works and found the from-scratch implementation really useful. What it’s missing, however, is the implementation of the `negative_log_prob` function that’s called in the main `hmc` function:
```python
for p0 in momentum.rvs(size=size):
    # Integrate over our path to get a new position and momentum
    q_new, p_new = leapfrog(
        samples[-1],
        p0,
        dVdq,
        path_len=path_len,
        step_size=step_size,
    )

    # Check Metropolis acceptance criterion
    start_log_p = negative_log_prob(samples[-1]) - np.sum(momentum.logpdf(p0))
    new_log_p = negative_log_prob(q_new) - np.sum(momentum.logpdf(p_new))
    if np.log(np.random.rand()) < start_log_p - new_log_p:
        samples.append(q_new)
    else:
        samples.append(np.copy(samples[-1]))
```
I understand (I think) that this function should calculate the posterior log probability that gets updated as sampling progresses through the chains. However, I’m having trouble conceptualizing what this function would look like.
A skeleton code or a link would be really helpful. Thank you!
For even more distributions, libraries like distrax, numpyro, or tensorflow_probability have extensive collections, and ways of constructing complicated densities.
Ah yeah, sorry: so MCMC is an algorithm for sampling from an (unnormalized) log probability. We of course have efficient ways of generating samples from a normal or Poisson or whatever distribution, but MCMC can still be used to sample from these.
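So for the `negative_log_prob` asked about above, a minimal sketch targeting a standard normal (so the sampler should recover N(0, 1)) could be as simple as:

```python
import numpy as np

def negative_log_prob(q):
    # Negative log density of a standard normal, up to an additive
    # constant -- MCMC never needs the normalizing constant.
    return 0.5 * np.sum(q ** 2)
```

Passing something like this in should give samples whose mean is near 0 and variance near 1.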
More commonly, as you point out, MCMC is used for distributions we cannot sample from efficiently, especially in Bayesian inference, where p(theta | data) can typically be calculated only up to a normalizing constant. You could do that here if you wanted! For example, here is 1-dimensional linear regression:
```python
import numpy as np
import jax.numpy as jnp
import jax.scipy.stats as jst

# something like y = 2x + noise
x = np.arange(10)
y = 2 * x + np.random.randn(10)

# fit a model y = mu * x + N(0, 1), with mu ~ N(0, 1)
def neg_log_prob(mu):
    log_prior = jst.norm.logpdf(mu, 0.0, 1.0)
    log_likelihood = jnp.sum(jst.norm.logpdf(y, mu * x, 1.0))
    return -(log_prior + log_likelihood)
```
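One more piece the leapfrog integrator needs is the gradient `dVdq`; since the log probability is written with jax here, you can get it automatically. A sketch, using a noiseless version of the data above so the answer is easy to check:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

x = jnp.arange(10.0)
y = 2.0 * x  # noiseless data, for illustration only

def neg_log_prob(mu):
    log_prior = jst.norm.logpdf(mu, 0.0, 1.0)
    log_likelihood = jnp.sum(jst.norm.logpdf(y, mu * x, 1.0))
    return -(log_prior + log_likelihood)

# jax.grad gives exactly the dVdq that leapfrog expects
dVdq = jax.grad(neg_log_prob)
```

At mu = 2 the likelihood gradient vanishes and only the prior contributes, so `dVdq(2.0)` should come out to 2.0.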
Thanks for the detailed example, it’s extremely helpful. I’m following everything except why the mean in your likelihood definition is `mu * x`. If I’m understanding the logic, shouldn’t it be `2 * (x + mu)`? I’m probably overlooking something simple.
I think you’d need to jaxify the log probability (and negate it!) first, but definitely! I also shied away from PyMC because the blog post (I think!) doesn’t cover:

- multiple dimensions
- change of variables

both of which you might have to deal with if you use a real PPL!
This might just be a me problem, but I am getting a “connection is not private” error that says “Attackers might be trying to steal your information from colindcarroll.com (for example, passwords, messages, or credit cards).” Are these posts on HMC still active?
EDIT: Just to note, I have tried on several computers/browsers/wifi networks, so it’s not just one device or browser for me.