Variational inference: diagnosing convergence

The main problems with standard ADVI as defined in Kucukelbir et al.'s paper are:

  1. High variance of stochastic gradients. The remedy is to up the number of Monte Carlo draws used for each gradient estimate (the default is a single draw); upping this to 5K or 50K helps immensely with optimization, but you’ll need a GPU to make that tractable (see the sketch after this list).

  2. Bad step size adaptation. You can try a grid of step sizes in parallel, but you have to be careful to pick step sizes that work both in the initial iterations and in the iterations nearer convergence (a grid search is sketched after the list as well).

  3. Bad reparameterization gradient for termination. If you’re not using a huge number of stochastic gradient iterations for the ELBO (the approximate negative KL), the stick-the-landing reparameterization gradient is more effective than the simple approach in the paper (it’s also shown in the sketch below).

  4. As a termination criterion, the ELBO can be tricky because it tends to flatten out quite a bit before actual convergence; you can see this in the OP’s plots. You can monitor the norm of the parameters themselves across iterations, or use something like Wasserstein distance, but that’s super expensive. (The sketch below checks both the ELBO and how much the parameters are still moving.)
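
To make items 1, 3, and 4 concrete, here is a minimal NumPy sketch of mean-field Gaussian VI for a black-box log density. It is not Stan’s ADVI implementation: the function names (`advi`, `logp`, `grad_logp`) and all the step-size and tolerance defaults are illustrative assumptions. It uses many Monte Carlo draws per gradient, offers the stick-the-landing estimator as an option, and terminates only when the ELBO has flattened *and* the parameters have stopped moving.

```python
import numpy as np

def advi(logp, grad_logp, dim, step_size=0.05, n_mc_draws=64, stl=True,
         max_iters=20_000, check_every=100,
         elbo_rel_tol=1e-4, param_rel_tol=1e-4, seed=0):
    """Mean-field Gaussian VI for a black-box log density (illustrative sketch).

    logp(z)      -> scalar log density at z (up to an additive constant)
    grad_logp(z) -> gradient of logp at z, shape (dim,)

    Returns (mu, omega, elbo_history) with q(z) = Normal(mu, exp(omega)).
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros(dim)      # variational means
    omega = np.zeros(dim)   # log standard deviations

    elbo_hist = []
    prev_params = np.concatenate([mu, omega])

    for t in range(1, max_iters + 1):
        sigma = np.exp(omega)

        # (1) Use many Monte Carlo draws per gradient to cut its variance;
        #     a single draw (the usual default) is extremely noisy.
        eps = rng.standard_normal((n_mc_draws, dim))
        z = mu + sigma * eps                      # reparameterized draws
        g = np.stack([grad_logp(z_i) for z_i in z])

        if stl:
            # (3) Stick-the-landing estimator: differentiate -log q only
            #     through z, dropping the zero-mean score-function term.
            path = g + eps / sigma                # grad of logp(z) - log q(z) in z
            grad_mu = path.mean(axis=0)
            grad_omega = (path * eps * sigma).mean(axis=0)
        else:
            # Plain reparameterization gradient with analytic entropy term
            # (the entropy contributes +1 to each omega component).
            grad_mu = g.mean(axis=0)
            grad_omega = (g * eps * sigma).mean(axis=0) + 1.0

        # Gradient ascent on the ELBO with a fixed step size.
        mu += step_size * grad_mu
        omega += step_size * grad_omega

        if t % check_every == 0:
            # Monte Carlo ELBO estimate: E_q[log p(z)] + entropy of q.
            entropy = omega.sum() + 0.5 * dim * (1.0 + np.log(2.0 * np.pi))
            elbo = np.mean([logp(z_i) for z_i in z]) + entropy
            elbo_hist.append(elbo)

            # (4) Don't trust a flat ELBO alone: also check how much the
            #     variational parameters are still moving.
            params = np.concatenate([mu, omega])
            param_move = np.linalg.norm(params - prev_params) / (
                np.linalg.norm(prev_params) + 1e-12)
            prev_params = params

            elbo_flat = (len(elbo_hist) > 1 and
                         abs(elbo_hist[-1] - elbo_hist[-2])
                         < elbo_rel_tol * (abs(elbo_hist[-2]) + 1e-12))
            if elbo_flat and param_move < param_rel_tol:
                break

    return mu, omega, elbo_hist
```

For item 2, one workaround is to fit the same model over a grid of fixed step sizes (ideally in parallel) and keep whichever run ends at the highest ELBO. A sketch reusing the hypothetical `advi` above:

```python
def best_step_size(logp, grad_logp, dim, step_sizes=(1.0, 0.1, 0.01, 0.001)):
    """Fit with each step size and keep the run whose final ELBO is highest."""
    fits = []
    for s in step_sizes:  # in practice, launch these runs in parallel
        mu, omega, elbo_hist = advi(logp, grad_logp, dim, step_size=s)
        fits.append((elbo_hist[-1], s, mu, omega))
    best_elbo, best_s, mu, omega = max(fits, key=lambda f: f[0])
    return best_s, mu, omega
```

Comparing the runs on their final ELBO (rather than on the first few hundred iterations) is what guards against a step size that looks fine early on but stalls near convergence.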

There are a couple of good references on this:

As a side note, you can often get an approximation that’s as good as ADVI’s, if not better, from just 100 iterations of NUTS.
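
If you want to try that comparison, here is a hedged CmdStanPy sketch. `model.stan` and `data.json` are placeholders for your own model and data, and argument names can vary a bit across CmdStanPy versions:

```python
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="model.stan")   # placeholder model file

# A deliberately short NUTS run: 100 warmup + 100 sampling iterations per chain.
nuts = model.sample(data="data.json", chains=4,
                    iter_warmup=100, iter_sampling=100)

# Stan's ADVI for comparison, with more gradient/ELBO draws than the defaults;
# compare its posterior mean estimates against the NUTS summary below.
advi_fit = model.variational(data="data.json", algorithm="meanfield",
                             grad_samples=100, elbo_samples=1000,
                             require_converged=False)

print(nuts.summary().head())
```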

I think the typical approach in big-data regimes is to use a holdout set, and terminate training when the test loss starts to U-turn.

You can’t quite do that in ML because of double descent: the held-out loss can turn back down again after it first turns up. There’s also no natural holdout set when you’re just given a black-box target density with gradients. If you have a structured model where it makes sense to do posterior predictive inference for held-out data, you could use that to mirror what the ML folks do (a sketch of that follows below). But then you’re not necessarily fitting the model you want to fit, but rather a version that’s regularized by early stopping.
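
If you do go the held-out posterior-predictive route, a generic early-stopping wrapper might look like the sketch below. Everything in it is hypothetical: `step_fn` advances whatever VI optimizer you’re using by one iteration, `draw_q` samples latent parameters from the current approximation, and `log_lik` is the model’s per-observation log likelihood.

```python
import numpy as np

def fit_with_predictive_early_stopping(step_fn, draw_q, log_lik, y_holdout,
                                       max_iters=20_000, check_every=200,
                                       patience=3, n_draws=100, seed=0):
    """Stop a VI fit when held-out log predictive density stops improving.

    step_fn()      -> advance the variational fit by one iteration
    draw_q(rng, n) -> n draws of the latent parameters from the current q
    log_lik(z, y)  -> log p(y | z) for one draw z and one held-out point y
    """
    rng = np.random.default_rng(seed)
    best, since_best = -np.inf, 0
    for t in range(1, max_iters + 1):
        step_fn()
        if t % check_every:
            continue
        zs = draw_q(rng, n_draws)
        # Monte Carlo estimate of log p(y | data) per held-out point,
        # averaged over the holdout set (a crude ELPD estimate).
        lpd = np.mean([
            np.logaddexp.reduce([log_lik(z, y) for z in zs]) - np.log(len(zs))
            for y in y_holdout
        ])
        if lpd > best:
            best, since_best = lpd, 0
        else:
            since_best += 1
            if since_best >= patience:  # held-out score has U-turned
                break
    return t, best
```

The `patience` check is exactly the early-stopping regularization mentioned above, which is why the resulting fit targets predictive performance rather than the exact posterior of the original model.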
