SVGD NaN errors when run with 1 particle


I am getting NaN errors while trying to fit a very simple model given below using SVGD with 1 particle. Any idea why?


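```python
import numpy as np
import pymc3 as pm

x_obs = np.random.normal(loc=10, scale=1.0, size=1000)
with pm.Model() as mdl:
    mu = pm.Normal('mu', mu=0.0, sd=1.0)
    x = pm.Normal('x', mu=mu, sd=1.0, observed=x_obs)
    optmzr = pm.SVGD(n_particles=1)
    approx = optmzr.fit()  # Inference.fit runs its default n=10000 iterations
    trace = approx.sample(100)
```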

Error trace:

```
  0%|          | 0/10000 [00:00<?, ?it/s]

FloatingPointError                        Traceback (most recent call last)
<ipython-input-17-102169e6999f> in <module>()
      4     x = pm.Normal('x', mu=mu, sd=1.0, observed=x_obs)
      5     optmzr = pm.SVGD(n_particles=1)
----> 6     approx =
      7     trace = approx.sample(100)

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/variational/ in fit(self, n, score, callbacks, progressbar, **kwargs)
    136                 state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    137             else:
--> 138                 state = self._iterate_without_loss(0, n, step_func, progress, callbacks)
    140         # hack to allow access to loss hist

~/anaconda3/envs/python3.6/lib/python3.6/site-packages/pymc3/variational/ in _iterate_without_loss(self, s, _, step_func, progress, callbacks)
    169                     except IndexError:
    170                         pass
--> 171                     raise FloatingPointError('\n'.join(errmsg))
    172                 for callback in callbacks:
    173                     callback(self.approx, None, i+s+1)

FloatingPointError: NaN occurred in optimization. 
The current approximation of RV `mu`.ravel()[0] is NaN.
Try tracking this parameter:
```


You absolutely need more than one particle for SVGD to work. We could maybe forbid `n_particles=1`.
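One plausible source of the NaN, sketched under an assumption I have not verified against PyMC3's source: SVGD commonly uses an RBF kernel whose bandwidth comes from the median heuristic in the SVGD paper, h = med² / log n, where med is the median pairwise distance between particles. With a single particle the only pairwise distance is 0 and log 1 = 0, so h = 0/0, and that NaN propagates into every kernel value and then into the particle update. The helpers below (`median_heuristic_bandwidth`, `rbf`) are hypothetical, written in plain Python for illustration, and are not PyMC3's implementation.

```python
import math
import statistics


def median_heuristic_bandwidth(particles):
    """Paper-style bandwidth h = med^2 / log(n) for scalar particles.

    Hypothetical helper for illustration; not PyMC3's actual code.
    """
    n = len(particles)
    # Median of |x_i - x_j| over all pairs (i, j), diagonal included.
    med = statistics.median(abs(a - b) for a in particles for b in particles)
    denom = math.log(n)
    if denom == 0.0:
        # n == 1: med == 0 and log(1) == 0, so h = 0/0. NumPy/Theano would
        # silently produce NaN here; plain Python would raise, so return NaN.
        return float("nan")
    return med ** 2 / denom


def rbf(a, b, h):
    """RBF kernel k(a, b) = exp(-(a - b)^2 / h); a NaN bandwidth yields NaN."""
    return math.exp(-((a - b) ** 2) / h)


# One particle: the bandwidth is undefined and the kernel is NaN.
h1 = median_heuristic_bandwidth([0.3])
k1 = rbf(0.3, 0.3, h1)

# Three particles: the bandwidth is positive and the kernel is finite.
h2 = median_heuristic_bandwidth([0.3, 1.5, -0.7])
k2 = rbf(0.3, 1.5, h2)
print(h1, k1, h2, k2)
```

If this assumption matches PyMC3, it would explain why the optimizer reports `mu` becoming NaN on the very first iterations rather than drifting there.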

Is that an implementation limitation specific to PyMC3? In the paper, the authors say it works as a MAP estimate when using a single particle.
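The paper's claim can be checked on this exact model. SVGD moves each particle by φ(x) = (1/n) Σ_j [k(x_j, x) ∇log p(x_j) + ∇_{x_j} k(x_j, x)]. With n = 1 and any fixed, positive bandwidth, k(x, x) = 1 and the repulsive term ∇k(x, x) = 0, so the update is plain gradient ascent on log p. For the model in the question (prior mu ~ N(0, 1), likelihood x_i ~ N(mu, 1)), ∇log p(mu) = Σ x_i − (N + 1) mu, so the MAP is Σ x_i / (N + 1). The sketch below assumes a fixed bandwidth h = 1 (sidestepping any data-dependent bandwidth, which degenerates when there is only one particle) and a hand-picked step size; it is plain Python, not PyMC3's implementation.

```python
import math
import random

random.seed(0)
x_obs = [random.gauss(10.0, 1.0) for _ in range(1000)]
N = len(x_obs)
sum_x = sum(x_obs)


def grad_log_p(mu):
    """d/dmu log p(mu | x) for mu ~ N(0, 1) and x_i ~ N(mu, 1)."""
    return -mu + (sum_x - N * mu)


def rbf(a, b, h=1.0):
    """RBF kernel with a fixed (assumed) bandwidth h."""
    return math.exp(-((a - b) ** 2) / h)


def grad_rbf_first_arg(a, b, h=1.0):
    """Derivative of exp(-(a - b)^2 / h) with respect to a."""
    return -2.0 * (a - b) / h * rbf(a, b, h)


def svgd_step(particles, step_size, h=1.0):
    """One SVGD update x_i <- x_i + eps * phi(x_i) for scalar particles."""
    n = len(particles)
    updated = []
    for x in particles:
        phi = sum(
            rbf(xj, x, h) * grad_log_p(xj) + grad_rbf_first_arg(xj, x, h)
            for xj in particles
        ) / n
        updated.append(x + step_size * phi)
    return updated


particles = [0.0]  # a single particle
for _ in range(2000):
    # Stable because step_size * (N + 1) < 2.
    particles = svgd_step(particles, step_size=1e-4)

map_estimate = sum_x / (N + 1)
print(particles[0], map_estimate)
```

With one particle the kernel terms drop out exactly, so the particle converges to the closed-form MAP; the NaN therefore points to how the kernel is evaluated for n = 1 rather than to the SVGD update itself.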

Hmm, I may have spoken too soon. I’ll have to take a closer look at your problem.

@lucianopaz Have you found a workaround for this?