Transport maps in PyMC

I’ve been studying the NeuTra paper, and I think it would fit well within the PyMC API (automated reparameterization sounds very “inference button”-y).

Could transport maps of this kind be included as a feature in PyMC? I would like to try implementing it, but I’d like to hear from the community first. If so, how do you think the API should be designed?

From a user perspective, I think it would make sense to have a pm.NeuTra function that performs the following steps:

  1. Train the inverse autoregressive flow (IAF) as described in the paper
  2. Create RVs representing the z-space, with logp = lambda z: pm.Normal.dist(0, 1).logp(z) + at.sum(sigma(z)), where sigma is the function defined in the IAF.
  3. RVs instantiated by the user before step 2 (the actual parameters of interest) are redefined as the push-forward of the z-space RVs through the IAF.
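For reference, the density NeuTra runs HMC on in z-space is the target pulled back through the flow, log p(f(z)) + log|det ∂f/∂z|; for an affine IAF layer x = μ(z) + σ(z)·z, where μ_i and σ_i depend only on z_<i, the log-determinant is Σ log σ(z). Below is a minimal NumPy sketch of steps 2 and 3, assuming a single affine IAF layer whose autoregressive network is replaced by a masked linear map with random placeholder weights (a real run would use the network trained in step 1). The target is the funnel model from the example below, with sigma on the log scale and two funnel dimensions instead of ten; all function names are hypothetical.

```python
import numpy as np

D = 3  # 1 log-sigma dimension + 2 funnel dimensions
rng = np.random.default_rng(0)

# Placeholder IAF parameters (random, NOT trained). The strictly lower-triangular
# mask makes mu_i and log_sigma_i depend only on z_<i (the autoregressive property).
mask = np.tril(np.ones((D, D)), k=-1)
W_mu, b_mu = rng.normal(size=(D, D)) * mask, rng.normal(size=D)
W_s, b_s = 0.1 * rng.normal(size=(D, D)) * mask, 0.1 * rng.normal(size=D)


def iaf_forward(z):
    """One affine IAF layer x = mu(z) + sigma(z) * z; returns (x, log|det dx/dz|)."""
    mu = W_mu @ z + b_mu
    log_sigma = W_s @ z + b_s
    x = mu + np.exp(log_sigma) * z
    # The Jacobian is lower triangular with diagonal sigma(z),
    # so log|det| reduces to sum(log sigma(z)).
    return x, np.sum(log_sigma)


def target_logp(x):
    """Funnel model: log_sigma ~ N(0, 1) (i.e. sigma ~ LogNormal(0, 1) on the
    log scale), funnel ~ N(0, exp(log_sigma))."""
    log_sigma, funnel = x[0], x[1:]
    c = 0.5 * np.log(2 * np.pi)
    lp = -0.5 * log_sigma**2 - c
    lp += np.sum(-0.5 * (funnel * np.exp(-log_sigma)) ** 2 - log_sigma - c)
    return lp


def z_space_logp(z):
    """Density HMC would sample in z-space: log p(f(z)) + log|det df/dz|."""
    x, log_det = iaf_forward(z)
    return target_logp(x) + log_det


def push_forward(z_draws):
    """Map z-space draws (one per row) back to the original parameters."""
    return np.array([iaf_forward(z)[0] for z in z_draws])
```

In this sketch, `z_space_logp` is what the sampler would see, and `push_forward` maps the resulting z-space draws back to the log of `sigma` (first column) and `funnel` (remaining columns).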

So, the user would type something like this:

import pymc as pm

with pm.Model() as model:
    # the model is defined as usual, for example:
    sigma = pm.LogNormal("sigma", 0, 1)
    funnel = pm.Normal("funnel", 0, sigma, shape=10)
    # the next line (a proposed function, not part of PyMC yet) trains
    # the IAF neural network and redefines the model RVs
    pm.NeuTra()
    # then, HMC sampling goes as usual
    inference_data = pm.sample()
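The training step that the proposed pm.NeuTra() call would hide (step 1 above) fits the flow q by maximizing the ELBO, i.e. minimizing KL(q‖p). Below is a NumPy sketch of a Monte Carlo ELBO estimate, assuming a diagonal affine flow (an IAF with all autoregressive weights set to zero) and a Gaussian stand-in target; the function names are illustrative. A real implementation would maximize this objective with stochastic gradients obtained through automatic differentiation.

```python
import numpy as np


def affine_flow(z, loc, log_scale):
    """Diagonal affine flow x = loc + exp(log_scale) * z, applied row-wise.

    Returns x and the per-draw log|det dx/dz| (the sum of log_scale)."""
    x = loc + np.exp(log_scale) * z
    log_det = np.full(z.shape[0], np.sum(log_scale))
    return x, log_det


def normal_logpdf(x, mu, sd):
    return -0.5 * ((x - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)


def elbo(target_logp, loc, log_scale, num_draws=1000, seed=0):
    """Monte Carlo ELBO: E_q[log p(x) - log q(x)] with x = f(z), z ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((num_draws, loc.size))
    x, log_det = affine_flow(z, loc, log_scale)
    # Change of variables: log q(x) = log N(z; 0, I) - log|det df/dz|
    log_q = normal_logpdf(z, 0.0, 1.0).sum(axis=1) - log_det
    return np.mean(target_logp(x) - log_q)


# Stand-in target: independent Gaussians (normalized, so the ELBO is at most 0).
mu_true = np.array([1.0, -2.0])
sd_true = np.array([0.5, 3.0])


def gaussian_target(x):
    return normal_logpdf(x, mu_true, sd_true).sum(axis=1)
```

When the flow matches the target exactly (`loc = mu_true`, `log_scale = log(sd_true)`), every draw contributes log p − log q = 0, so the estimate hits the maximum of 0; any mismatch makes it negative by roughly KL(q‖p).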

Also, NeuTra has apparently already been implemented in NumPyro (numpyro.infer.reparam.NeuTraReparam), but I’m still struggling to understand how it works.

Some references on NeuTra:

Excited to hear back from you all!

This sounds like an interesting case study. I am not sure about adding it to PyMC per se, but I am sure it would be neat as a standalone script/library. If it proves useful, it could also go into https://github.com/pymc-devs/pymc-experimental

Bonus points if it’s implemented with PyMC v4 (it should actually be easier there).

I see. I will start implementing it then, thank you!

I agree that this would be a nice contribution. Leaning more toward VI, but also very interesting for inference, would be an implementation of Pathfinder: [2108.03782] Pathfinder: Parallel quasi-Newton variational inference