Variational inference: diagnosing convergence

Slightly tangential, but learning rate schedulers for ADVI have a PR that is just waiting for someone to come along and get it back on track :slight_smile:

Some more general comments:

  • Convergence detection when doing SGD is a pretty complex subject in the larger ML world, no? I think the typical approach in big-data regimes is to use a holdout set and terminate training when the held-out loss stops improving and starts to rise (the “U-turn”).
  • The PyMC “default” convergence check that you get when you do init="advi+adapt_diag" is:
    cb = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
    ]
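The holdout-loss approach from the first bullet can be sketched in plain Python. Everything here is hypothetical for illustration — `fit_with_early_stopping`, the `patience` knob, and the toy loss are my own names, not part of PyMC:

```python
def fit_with_early_stopping(train_step, holdout_loss, max_steps=10_000, patience=10):
    """Run SGD-style steps, stopping once the holdout loss has failed to
    improve for `patience` consecutive evaluations (the loss "U-turn")."""
    best = float("inf")
    stale = 0
    for step in range(max_steps):
        train_step()
        loss = holdout_loss()
        if loss < best:
            best, stale = loss, 0
        else:
            stale += 1
            if stale >= patience:
                return step, best  # stopped early, past the U-turn
    return max_steps - 1, best

# Toy example: a holdout loss that falls until "step" 100 and then rises.
state = {"t": 0}
def train_step():
    state["t"] += 1
def holdout_loss():
    return (state["t"] - 100) ** 2

stop_step, best_loss = fit_with_early_stopping(train_step, holdout_loss)
```

With `patience=10`, the loop stops ten evaluations after the minimum, so the reported `best_loss` is still the value at the bottom of the U.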

Rightly or wrongly, I’ve viewed these default callbacks as a reasonable “prior” on how to check for convergence.