ADVI changes start values to NaN

I get the error FloatingPointError: NaN occurred in optimization. in the first iteration of ADVI, despite a good starting point:

with model:
    vinfer = pm.ADVI(obj_optimizer=pm.adam(learning_rate=1e-6))
    ...
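    # overwrite the approximation's shared parameters:
    # params[0] holds the means (mu), params[1] the rhos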
    vinfer.approx.params[0].set_value(vinfer.approx.bij.map(model_means))
    vinfer.approx.params[1].set_value(vinfer.approx.bij.map(model_rhos))
    point = vinfer.approx.groups[0].bij.rmap(vinfer.approx.params[0].eval())
    print(point['w_f_stickbreaking']) # here the values seem fine
    approx = vinfer.fit(
        n=int(maxIter),
        obj_optimizer=obj_optimizer,
        progressbar=True,
    )

returns

[-2.06601772 -2.11281976 -2.15386025 -2.18520134 -2.25892385  4.07481609
  4.76060229  4.16298305  1.28004891]

0.00% [0/1000000 00:00<00:00]

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-57-39092e69cb13> in <module>
     34     point = vinfer.approx.groups[0].bij.rmap(vinfer.approx.params[0].eval())
     35     print(point['w_f_stickbreaking'])
---> 36     aprox = vinfer.fit(
     37         n=int(maxIter),
     38         callbacks=[sf],

~/.conda/envs/lastEnv/lib/python3.8/site-packages/pymc3/variational/inference.py in fit(self, n, score, callbacks, progressbar, **kwargs)
    150             progress = range(n)
    151         if score:
--> 152             state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    153         else:
    154             state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

~/.conda/envs/lastEnv/lib/python3.8/site-packages/pymc3/variational/inference.py in _iterate_with_loss(self, s, n, step_func, progress, callbacks)
    238                     except IndexError:
    239                         pass
--> 240                     raise FloatingPointError("\n".join(errmsg))
    241                 scores[i] = e
    242                 if i % 10 == 0:

FloatingPointError: NaN occurred in optimization. 
The current approximation of RV `w_t_stickbreaking`.ravel()[0] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[1] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[2] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[3] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[4] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[5] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[6] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[7] is NaN.
The current approximation of RV `w_t_stickbreaking`.ravel()[8] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[0] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[1] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[2] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[3] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[4] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[5] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[6] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[7] is NaN.
The current approximation of RV `w_f_stickbreaking`.ravel()[8] is NaN.
Try tracking this parameter: http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters

and when I check the starting values again, they appear broken:

>>> point = vinfer.approx.groups[0].bij.rmap(vinfer.approx.params[0].eval())
>>> print(point['w_f_stickbreaking']) # now the values seem broken
[nan nan nan nan nan nan nan nan nan]

This happens even with the reduced learning rate for adam. Trying other optimizers such as adagrad_window also did not resolve the error.

The variables in question, w_t_stickbreaking and w_f_stickbreaking, both come from pm.Normal.
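For completeness, pm.fit also accepts a start argument, which would set the initial point without touching the shared parameters directly. A minimal sketch (my assumption: start only initializes the means, so the rhos would presumably still have to be set via params[1]):

with model:
    approx = pm.fit(
        n=int(maxIter),
        method='advi',
        start=model_means,  # dict keyed by transformed names, e.g. 'w_f_stickbreaking'
        obj_optimizer=pm.adam(learning_rate=1e-6),
        progressbar=True,
    )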

@junpenglao's comment here may suggest that this is an issue with how the approximation's score function is set up.

However, where would I start debugging this, and why does ADVI change its parameters before the first iteration even runs?
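The link at the end of the error message suggests tracking the parameters; a minimal sketch of that, following the variational API quickstart:

with model:
    vinfer = pm.ADVI(obj_optimizer=pm.adam(learning_rate=1e-6))
    tracker = pm.callbacks.Tracker(
        mean=vinfer.approx.mean.eval,  # record the means at every iteration
        std=vinfer.approx.std.eval,    # record the stds at every iteration
    )
    approx = vinfer.fit(n=int(maxIter), callbacks=[tracker])

# tracker['mean'] should then reveal the iteration at which the values turn NaN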

Definitely not an answer, more of a general comment: I'm curious why PyMC3 is gaining interest in variational inference; outside of Bayesian NNs, I believe that NUTS/HMC is still preferred for Bayesian inference. Will PyMC3 move to include Bayesian NNs in its medium- to long-range strategy?

@jbuddy_13 I can only speak for myself: When a model has a lot of parameters and the posterior of each parameter can be expected to be approximately normal, ADVI is much quicker than sampling. If needed, the result can be refined through sampling by drawing start values from the ADVI posterior. I find this convenient in many applications. However, you may find better answers in a separate discussion.
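A rough sketch of that workflow (illustrative only; the iteration, draw, and chain counts are placeholders):

with model:
    approx = pm.fit(n=20000, method='advi')
    # draw one start point per chain from the ADVI posterior
    start = list(approx.sample(draws=4))
    trace = pm.sample(1000, chains=4, start=start)

pm.sample(init='advi') automates essentially the same procedure.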

Furthermore, Bayesian NNs are already part of PyMC3: Variational Inference: Bayesian Neural Networks (PyMC3 3.10.0 documentation)
