Hey! Glad you’re finding it useful! Two changes I’d make if I wrote this today:
- The `neg_log_prob` could maybe be better called `neg_log_density`.
- I don't know anymore why I negated it (it makes some of the math harder to follow).
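Anyways! For, say, a normal distribution with mean 2 and standard deviation 1, this could be just `negative_log_prob = lambda x: 0.5 * (x - 2)**2`. The additive normalizing constant can be dropped, but the factor of 0.5 matters: without it you'd get a standard deviation of 1/sqrt(2) instead of 1.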
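If you want to convince yourself the shortcut is fine, here's a rough check (a sketch, assuming only `jax` is installed). The exact negative log density of N(2, 1) is 0.5 * (x - 2)**2 plus the constant 0.5 * log(2π). That constant cancels in density ratios and vanishes in gradients, so dropping it doesn't change anything a sampler sees.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

# Hand-written negative log density of N(2, 1), normalizing constant dropped.
negative_log_prob = lambda x: 0.5 * (x - 2.0) ** 2

# Exact negative log density from jax.scipy.stats (loc=2, scale=1).
exact = lambda x: -jst.norm.logpdf(x, 2.0, 1.0)

xs = jnp.linspace(-3.0, 5.0, 9)

# The two differ only by the constant 0.5 * log(2 * pi) at every point.
diff = exact(xs) - negative_log_prob(xs)
print(diff)

# Gradients agree, since the constant disappears when differentiated.
# jax.grad needs a scalar-valued function, so map it over the points.
grad_hand = jax.vmap(jax.grad(negative_log_prob))(xs)
grad_exact = jax.vmap(jax.grad(exact))(xs)
print(jnp.allclose(grad_hand, grad_exact))
```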
More generally, you can use `jax.scipy.stats`, which has some built-in distributions. For a normal with mean 1 and standard deviation 10:

```python
import jax.scipy.stats as jst
negative_log_prob = lambda x: -jst.norm.logpdf(x, 1., 10.)  # loc=1 (mean), scale=10 (standard deviation)
```
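Since that's just a JAX function, you get gradients and batching for free. A small sketch (assuming nothing beyond `jax`; `grad_fn` is just an illustrative name): `jst.norm.logpdf` takes `(x, loc, scale)` with `scale` being the standard deviation, so the gradient of the negative log density should be (x - 1) / 100.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

# Negative log density of a normal with mean 1 and standard deviation 10.
negative_log_prob = lambda x: -jst.norm.logpdf(x, 1.0, 10.0)

# jax.grad works here because the function is scalar-valued for scalar x.
grad_fn = jax.grad(negative_log_prob)

# Analytically: d/dx [(x - 1)**2 / (2 * 10**2)] = (x - 1) / 100.
print(grad_fn(3.0))

# Evaluate the gradient at a batch of points with jax.vmap.
xs = jnp.array([-9.0, 1.0, 11.0])
print(jax.vmap(grad_fn)(xs))
```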
For even more distributions, libraries like distrax, numpyro, or tensorflow_probability have extensive collections, and ways of constructing complicated densities.
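Before reaching for those, you can get surprisingly far with `jax.scipy` alone. Here's a hedged sketch of a two-component Gaussian mixture built from `jst.norm.logpdf` and `logsumexp`; the weights, means, and scales are made up for illustration, and this doesn't try to mirror any of those libraries' mixture APIs.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst
from jax.scipy.special import logsumexp

# Illustrative (made-up) mixture: 30% N(-2, 0.5) and 70% N(3, 1).
weights = jnp.array([0.3, 0.7])
locs = jnp.array([-2.0, 3.0])
scales = jnp.array([0.5, 1.0])

def negative_log_prob(x):
    # log p(x) = log sum_k w_k * N(x | loc_k, scale_k), computed stably in log space.
    component_logpdfs = jst.norm.logpdf(x, locs, scales)  # shape (2,) for scalar x
    return -logsumexp(component_logpdfs + jnp.log(weights))

# Still an ordinary JAX function, so it can be differentiated as before.
grad_fn = jax.grad(negative_log_prob)
print(negative_log_prob(0.0), grad_fn(0.0))
```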