Help navigating implementation of new inference algorithm

Hi, we recently submitted an article on a new inference algorithm called the Marginal Unbiased Score Expansion (MUSE). In short, it’s a VI competitor, and could be considered a type of SBI. The initial code is in Julia and plugs into Turing.jl, one of Julia’s PPLs (example of how it looks and my post on the Julia forums).

I’d like to write a PyMC interface as well to reach a broader audience. I’ve never used PyMC before but have been familiarizing myself with it the last few days. I was wondering if someone could offer advice on two things:

  1. As I’ve learned, there’s just about to be an under-the-hood transition between PyMC3 and PyMC4 (not yet released). Any advice on which I should target? Perhaps relevant is that MUSE really shines on high-dimensional problems, so perhaps the more performant PyMC4 would make sense? If the answer is that I may as well go for PyMC4, should I just check out the master branch of pymc and that’s it? Are there any other dependent packages I need?

  2. Any advice or examples I can follow on writing a custom “sampler”? MUSE ultimately just needs forward simulations and gradients of the joint likelihood, and is otherwise quite simple.

Thanks for any help.

Very interesting, and we would love to help, but to better advise maybe you could provide a bit of pseudocode as an example?

A short answer for now:

  1. PyMC4 is no more; moving forward, everything will be in GitHub - pymc-devs/pymc: Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with Aesara (in other words, PyMC v4 will be the next major release).
  2. For that I would like to understand a bit more of how the algorithm works. From MuseInference.jl · MuseInference, it seems to require the prior log_prob for \theta, the joint log likelihood and its gradient (to get the score of the marginal likelihood), and a prior simulation function (to get the approximation for the score function). I think the closest thing might be the implementation of Sequential Monte Carlo, which isolates similar components from a PyMC model: pymc/pymc/smc at main · pymc-devs/pymc · GitHub

Thanks for the reply, here’s the gist of it. The goal is to come up with a Gaussian approximation to the marginal posterior P(θ|x) = ∫dz P(θ,z|x). The following computes the mean (the covariance of the estimate is different but uses the identical pieces):

θ = # initial guess for θ
θlast = # anything far from θ (e.g. Inf), so the loop runs at least once
H = # some guess for Hessian of θ -> logP(θ|x)

while norm(θ - θlast) > θtol:

    # a bunch of simulated x's generated from P(x,z|θ)
    x_sims = [sample_prior_predictive(θ).x for i in 1:nsims] 

    # zMAP maximizes the function z -> logP(x,z|θ)
    zMAP_data = zMAP(x, θ)
    zMAP_sims = [zMAP(x_sim, θ) for x_sim in x_sims]

    # ∇θ_logLike is gradient of θ -> logP(x,z|θ)
    g_data = ∇θ_logLike(zMAP_data, x, θ)
    g_sims = [∇θ_logLike(zMAP_sim, x_sim, θ) for (zMAP_sim, x_sim) in zip(zMAP_sims,x_sims)]

    # gradient of θ -> logP(θ)
    g_prior = ∇θ_logPrior(θ) 

    θlast = θ
    θ -= H \ (g_data - mean(g_sims) + g_prior)

The one piece I’ve already found the PyMC name for is sample_prior_predictive, although I haven’t found the right way to condition it on θ. The gradient of the likelihood with respect to the latent z would be used in zMAP(), and I also need ∇θ_logLike() and ∇θ_logPrior() to evaluate the gradient of the likelihood and prior w.r.t. θ. That’s everything.
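To make the loop above concrete, here is a minimal pure-Python sketch on a toy model that is an assumption of this example, not something from the paper: z_i ~ Normal(θ, 1), x_i ~ Normal(z_i, 1), with a Normal(0, prior_sd) prior on θ. For this model zMAP and all gradients are analytic, and H can be written down exactly rather than guessed. The function names and the choice to reuse the same noise draws at every iteration (so the update is deterministic and the tolerance check can trigger) are also choices made for this sketch.

```python
import random
from statistics import fmean


def z_map(x, theta):
    # Maximizer of z -> log P(x, z | theta); analytic for the toy model.
    return [(xi + theta) / 2 for xi in x]


def grad_theta_loglike(z, x, theta):
    # Gradient of theta -> log P(x, z | theta), which is sum_i (z_i - theta).
    return sum(zi - theta for zi in z)


def grad_theta_logprior(theta, prior_sd):
    # Gradient of theta -> log P(theta) for a Normal(0, prior_sd) prior.
    return -theta / prior_sd**2


def muse_mean(x, prior_sd=10.0, nsims=50, theta_tol=1e-8, max_iter=20, seed=0):
    rng = random.Random(seed)
    n = len(x)
    # Noise drawn once and reused at every iteration (assumption of this sketch).
    noise = [[(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(n)]
             for _ in range(nsims)]
    # Exact Hessian of theta -> log P(theta | x) for the toy model;
    # in general this would only be a guess.
    H = -n / 2 - 1 / prior_sd**2
    theta = 0.0
    for _ in range(max_iter):
        # Simulate x from P(x, z | theta): z = theta + e1, x = z + e2.
        x_sims = [[theta + e1 + e2 for e1, e2 in sim] for sim in noise]
        g_data = grad_theta_loglike(z_map(x, theta), x, theta)
        g_sims = [grad_theta_loglike(z_map(xs, theta), xs, theta) for xs in x_sims]
        g_prior = grad_theta_logprior(theta, prior_sd)
        theta_last = theta
        theta -= (g_data - fmean(g_sims) + g_prior) / H
        if abs(theta - theta_last) < theta_tol:
            break
    return theta


# Usage: simulate data at a known theta and compare with the exact
# posterior mean, which is available in closed form for this toy model.
data_rng = random.Random(1)
x_obs = [1.5 + data_rng.gauss(0, 1) + data_rng.gauss(0, 1) for _ in range(200)]
theta_hat = muse_mean(x_obs)
exact = (sum(x_obs) / 2) / (len(x_obs) / 2 + 1 / 10.0**2)
print(theta_hat, exact)
```

In this toy case the estimate differs from the exact posterior mean only by the Monte Carlo noise in mean(g_sims), which shrinks as nsims grows.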

Yes sorry, I guess my question really boils down to, would I expect that the code I write here is compatible with both the current latest release and the future PyMC v4, or is the relevant API changing such that I’d need to update / support both versions?

And thanks for the SMC example, am taking a look now.

I think you should prioritize implementing it in PyMC v4; since this is a new feature, backporting it to the current major release is not very necessary.

Now to the approximation itself, it is a pretty interesting algorithm so I spent some time to draft out some components - this should help you implement it.
Some notes:

  1. Forward sampling is more straightforward in PyMC v4, although conditioning is a feature we are working on (see do operator / conditioning · Discussion #5280 · pymc-devs/pymc · GitHub).
  2. Getting a log likelihood function that can be conditioned on a new x_sim takes a bit of effort (thanks @ricardoV94 for guiding me on how to do this correctly in v4); tldr is that you need to point to the correct random variable in PyMC.
  3. Inference is done in the transformed (unbounded) space, so we need a way to link the original space and the unbounded space for forward sampling.
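As a minimal sketch of what point 3 involves, assume a single positive parameter σ with a HalfNormal prior that is mapped to the unbounded space with a log transform. The function names below are made up for illustration and this is not PyMC's transform API; it only shows the two maps and the Jacobian term that has to be added to the log density when working in the unbounded space.

```python
import math


def to_unbounded(sigma):
    # Map sigma > 0 to the whole real line.
    return math.log(sigma)


def to_constrained(u):
    # Inverse map, used to push forward samples back to the original space.
    return math.exp(u)


def halfnormal_logp(sigma, scale=1.0):
    # Log density of HalfNormal(scale) at sigma > 0.
    return 0.5 * math.log(2 / math.pi) - math.log(scale) - 0.5 * (sigma / scale) ** 2


def logp_unbounded(u, scale=1.0):
    # Log density of u = log(sigma): add log|d sigma / d u| = u.
    return halfnormal_logp(to_constrained(u), scale) + u


u = to_unbounded(2.5)
print(u, to_constrained(u), logp_unbounded(u))
```

Optimization and gradients would then work on u, while forward simulation needs to_constrained(u) to get back a valid σ.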

I used your Neal’s Funnel example; details are in the notebook: Planet_Sakaar_Data_Science/discourse_8528 (MUSE).ipynb at main · junpenglao/Planet_Sakaar_Data_Science · GitHub

(It does not quite work yet; I guess it depends on how to specify and update the Hessian H.)

Side question on the approximation: while your paper discusses some of the parallels and differences with the Laplace approximation, I am wondering about the comparison between MUSE and INLA (which tries to use a Laplace approximation to integrate out z). INLA also requires the latent field z to be Gaussian-like, and I see that for MUSE you stated “Even for mildly non-Gaussian latent spaces, one expects the data dependence of the integral to be small, with most of the data-dependence instead captured by the MAP term.” Do you have any comment on wildly non-Gaussian latent spaces (e.g. mixtures)?

@junpenglao Thanks, this is an amazing amount of work you’ve put into this! And it’s PyMC stuff that would have taken me ages to figure out! I’m going through the notebook now and will see if I can get it fully working and matching the result of the Julia package; I think you’re like 90% of the way there. Will post back here and can also PR your notebook.

Noted about PyMC v4, that seems right, especially since I see a beta is out now. Re: INLA, thanks for pointing me to that, I didn’t know about it. I took a quick glance, although I need more time to understand it fully. One thing that stuck out at first is that it seems to require the observed variables to be conditionally independent, whereas MUSE has no such requirement (granted, that is the case for the toy funnel problem, but not e.g. for the “CMB lensing” real-world problem from the paper). On how well it works for non-Gaussian latent spaces, we showed that non-Gaussianity doesn’t change the asymptotic unbiasedness of MUSE; it might just make it suboptimal. If by mixture you mean a multi-modal latent space, I’m not sure we’ve thought about that carefully enough, but it’s an interesting question. I think there’s a sense in which you can imagine the estimate working with that, since intuitively a bunch of the MAPs for the different sims that are part of the algorithm will fall into different maxima. So it seems it can kind of account for multi-modality, but I can’t say anything more quantitative than that.

Awesome - INLA-like ideas are an area where we would love to have more implementations. A similar paper, in case you have not seen it: [2004.12550] Hamiltonian Monte Carlo using an adjoint-differentiated Laplace approximation: Bayesian inference for latent Gaussian models and beyond

Again, can’t thank you enough, seeing that notebook was invaluable and rapidly got me up to speed on Aesara stuff I needed to know.

Based on that, I’ve got the skeleton of a package going. I fixed one algorithmic issue with what you had, which was that the \frac{d}{d\theta} \log \mathcal{P}(x,z\,|\,\theta) computation was picking out the wrong term in the model (it was using only the \theta “node”, but it’s actually the z one that’s important). That’s likely why it wasn’t working for you. You can see a demo of the early API here:

https://cosmicmar.com/muse_inference/demo.html#With-PyMC

with the code here: muse_inference/pymc.py at main · marius311/muse_inference · GitHub

I had a couple of followup detailed questions as I’ve gotten more familiar with stuff:

  1. Is there a way to get the observed value out of a RV, so as to avoid the user needing to do the prob.x = x_obs line?
  2. Is there a way to pass a RandomState-type thing to the sampling functions to make things reproducible? I figured out how to set the Model.rng_seq as needed to make sample_prior_predictive reproducible (that’s the only sampler I needed).
  3. As you can see on that page, right now the PyMC version is 30x slower than Numpy and 5x slower than Jax. I’ve done absolutely no profiling yet, so it’s maybe too early, but I was wondering if anything glaringly obvious stands out?

Still plenty of work left, including the vector concat-ing stuff from your notebook which I still need to work through.

Nice!

Yes, you can do x = [model.rvs_to_values[v].data for v in model.observed_RVs]

Nothing obvious from a quick glance at the implementation (although, as you said, there is still the concatenating and transforming of the variables needed for the implementation to work on a general PyMC model). Usually for small functions Aesara should be faster than Jax, so if you could profile the function calls, we can find out which one is slow and try to find a solution.
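One generic way to get started on that profiling is Python's built-in cProfile. In the sketch below, slow_step, fast_step and one_iteration are hypothetical stand-ins for the pieces of a single MUSE iteration (they are not functions from muse_inference or PyMC); the idea is just to wrap a few iterations in a profiler and sort by cumulative time.

```python
import cProfile
import io
import pstats


def slow_step(n=200_000):
    # Hypothetical stand-in for the expensive call (e.g. a gradient evaluation).
    return sum(i * i for i in range(n))


def fast_step(n=1_000):
    # Hypothetical stand-in for a cheap call.
    return sum(range(n))


def one_iteration():
    slow_step()
    fast_step()


profiler = cProfile.Profile()
profiler.enable()
for _ in range(5):
    one_iteration()
profiler.disable()

# Collect the ten entries with the largest cumulative time into a string.
buffer = io.StringIO()
pstats.Stats(profiler, stream=buffer).sort_stats("cumulative").print_stats(10)
report = buffer.getvalue()
print(report)
```

The rows near the top of the report with the largest cumulative time are the ones worth looking at first. Note that this only sees time at the level of Python calls, so a compiled function shows up as a single entry.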
