Large Feature Number Dirichlet


I am struggling to debug this seemingly simple model:

import pymc3 as pm

with pm.Model() as model:
    ratio = pm.Dirichlet('ratios', a=alpha, testval=start)
    obs = pm.Multinomial('observations', n=total, p=ratio, observed=local_counts)
    trace = pm.sample(draws=1000, cores=8)

It fails with

ValueError: Bad initial energy: inf. The model might be misspecified.

But upon inspection I get a finite initial log-probability:

pri = pm.Dirichlet.dist(a=alpha).logp(start).eval()
obsp = pm.Multinomial.dist(n=total, p=start).logp(local_counts).eval()
pri + obsp.sum() # returns -2205368975.1055

Likewise, obs.logp(model.test_point) returns array(-2.20595029e+09).

The local_counts variable is rather large, with shape (269, 58288), and is an integer NumPy array. The other inputs are defined as follows:

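import numpy as np
import theano

start = local_counts.sum(axis=0)                   # total count per category
start = start / start.sum()                        # normalize to a probability vector
alpha = np.ones(local_counts.shape[1], dtype=int)  # flat Dirichlet prior, one entry per category
all(local_counts.sum(axis=1) == total)             # returns True
theano.config.floatX                               # returns 'float64'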

How should I debug this model? What am I overlooking?


Try trace = pm.sample(draws=1000, cores=8, init='adapt_diag'). The default initialization ('jitter+adapt_diag') adds random jitter to the starting values, which might move them to an invalid point.
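As a rough illustration of what the jitter does, here is a NumPy sketch. It assumes, based on our reading of PyMC3's initialization code, that the jitter is uniform noise in [-1, 1] added to each transformed (unconstrained) starting value; the helper jitter_start and the toy vector q0 are hypothetical. The jitter itself cannot turn a non-finite starting value into a finite one, so switching to init='adapt_diag' only helps when the unjittered start is already valid.

```python
import numpy as np

def jitter_start(q0, rng):
    """Hypothetical helper: add uniform(-1, 1) noise to each transformed start value.

    Assumption: this mimics the 'jitter' part of init='jitter+adapt_diag';
    check the PyMC3 version you use for the exact behaviour.
    """
    q0 = np.asarray(q0, dtype=float)
    return q0 + rng.uniform(-1.0, 1.0, size=q0.shape)

rng = np.random.default_rng(0)
q0 = np.array([0.9, -3.5, -np.inf, 2.2])  # toy transformed start with one non-finite entry
q_jittered = jitter_start(q0, rng)

# Finite entries move by at most 1; the -inf entry stays -inf whatever the noise.
print(np.isfinite(q_jittered))
```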


It still fails similarly:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (8 chains in 8 jobs)
NUTS: [ratios]
Sampling 8 chains:   0%|          | 0/12000 [00:00<?, ?draws/s]
RemoteTraceback                           Traceback (most recent call last)
Traceback (most recent call last):
  File "/homes/olymp/dominik.otto/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/", line 73, in run
  File "/homes/olymp/dominik.otto/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/", line 113, in _start_loop
    point, stats = self._compute_point()
  File "/homes/olymp/dominik.otto/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/", line 139, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/homes/olymp/dominik.otto/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/step_methods/", line 247, in step
    apoint, stats = self.astep(array)
  File "/homes/olymp/dominik.otto/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/step_methods/hmc/", line 117, in astep
    'might be misspecified.' %
ValueError: Bad initial energy: inf. The model might be misspecified.

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Bad initial energy: inf. The model might be misspecified.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-72-c91b37d83866> in <module>
      2     ratio = pm.Dirichlet('ratios', a=alpha, testval=start)
      3     obs = pm.Multinomial('observations', n=total, p=ratio, observed=local_counts)
----> 4     trace = pm.sample(draws=1000, cores=8, init='adapt_diag')

~/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/ in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    447             _print_step_hierarchy(step)
    448             try:
--> 449                 trace = _mp_sample(**sample_args)
    450             except pickle.PickleError:
    451                 _log.warning("Could not pickle model, sampling singlethreaded.")

~/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/ in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
    997         try:
    998             with sampler:
--> 999                 for draw in sampler:
   1000                     trace = traces[draw.chain - chain]
   1001                     if trace.supports_sampler_stats and draw.stats is not None:

~/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/ in __iter__(self)
    304         while self._active:
--> 305             draw = ProcessAdapter.recv_draw(self._active)
    306             proc, is_last, draw, tuning, stats, warns = draw
    307             if self._progress is not None:

~/Envs/noteEnv3/lib/python3.6/site-packages/pymc3/ in recv_draw(processes, timeout)
    221         if msg[0] == 'error':
    222             old = msg[1]
--> 223             six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
    224         elif msg[0] == 'writing_done':
    225             proc._readable = True

~/Envs/noteEnv3/lib/python3.6/site-packages/ in raise_from(value, from_value)

RuntimeError: Chain 2 failed.


It sounds like something is wrong with the gradient. Try following the steps in the thread "How to track a 'nan energy'?".


I ran the steps and got

{'ratios_stickbreaking__': array([ 0.98842164, -3.49124823, -0.05789614, ..., -7.64635372,
       -8.74488638,  2.23058342])}
[ 0.33077722 -0.60774967 -0.96290679 ... -0.89012621 -0.32552028
[-1.68691344  0.96953555  0.05625001 ...  0.99936306  0.99976111
[ 0.33077722 -0.60774967 -0.96290679 ... -0.89012621 -0.32552028
29253.630307597938

The energy is inf and logp is -inf, while all(np.isfinite(p0)) & all(np.isfinite(dlogp)) is True. So the step function seems to start from a different initial logp, which is not finite. How does this quantity differ from the ordinary logp of my model?
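For intuition on the energy check: in HMC/NUTS the energy is the potential energy -logp(q) at the transformed position q plus the kinetic energy of the momentum p. A minimal NumPy sketch, assuming an identity mass matrix and using a toy standard-normal log density in place of the model's (toy_logp and hamiltonian_energy are hypothetical names), shows that a finite momentum draw cannot rescue a start position with a non-finite entry:

```python
import numpy as np

def toy_logp(q):
    """Hypothetical stand-in for the model's log density on the transformed space."""
    return -0.5 * np.sum(q ** 2)

def hamiltonian_energy(logp_fn, q, p):
    """Potential energy -logp(q) plus kinetic energy 0.5 * p.p (identity mass matrix assumed)."""
    return -logp_fn(q) + 0.5 * np.dot(p, p)

rng = np.random.default_rng(1)
p0 = rng.standard_normal(4)                   # finite momentum draw
q_ok = np.array([0.9, -3.5, -0.06, 2.2])      # finite start position
q_bad = np.array([0.9, -3.5, -np.inf, 2.2])   # start position with one non-finite entry

energy_ok = hamiltonian_energy(toy_logp, q_ok, p0)
energy_bad = hamiltonian_energy(toy_logp, q_bad, p0)
print(energy_ok)   # finite
print(energy_bad)  # inf: logp is -inf, so the potential energy is +inf
```

In this toy setting the infinite energy comes entirely from the position, which matches finite p0 alongside logp = -inf.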

The problem might be in step._logp_dlogp_func, since all(np.isfinite(q0)) is False.


If all(np.isfinite(q0)) is False, some starting value is non-finite. This is not always visible in the test point, because here the free parameters are actually the transformed version of the Dirichlet variable. Internally, the test point is transformed onto the real line and then used as the starting value (q0 in this case).
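To see why a zero probability breaks the transformed start, here is a simplified stick-breaking map in NumPy. It is an illustrative sketch under assumptions, not PyMC3's StickBreaking transform, whose exact form (for example a centering shift) has changed between versions; the helper stick_breaking_forward and the toy start vector are hypothetical. Each of the first K-1 coordinates is the logit of the fraction of the remaining stick taken by category k, so a category with probability exactly 0 maps to -inf.

```python
import numpy as np

def stick_breaking_forward(x):
    """Hypothetical helper: map a length-K probability vector to K-1 reals.

    Simplified sketch; PyMC3's actual StickBreaking transform differs in details.
    """
    x = np.asarray(x, dtype=float)
    # Length of stick remaining before piece k is broken off: sum_{j >= k} x_j
    remaining = np.cumsum(x[::-1])[::-1][:-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        z = x[:-1] / remaining           # fraction of the remaining stick taken by piece k
        return np.log(z) - np.log1p(-z)  # logit(z); logit(0) = -inf

start = np.array([0.5, 0.0, 0.3, 0.2])  # toy start vector with one exact zero
q0 = stick_breaking_forward(start)
print(q0)
# Non-finite entries of q0 line up with zero probabilities in start
print(all(start[:-1][~np.isfinite(q0)] == 0))
```

Any transform that takes a log or logit of the probabilities runs into the same problem with exact zeros, even though the Dirichlet and Multinomial log-probabilities on the simplex can still be finite there.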

My guess is that the number of categories here is too large (58288), which makes some of the probabilities extremely small, which in turn causes numerical errors.


You are right. I investigated the test value for the Dirichlet distribution and found that all(start[:-1][~np.isfinite(q0)] == 0) is True, i.e. every non-finite entry of q0 corresponds to a category whose starting probability is exactly zero. It does not really make sense for any p_i to be exactly zero, so I chose a nicer start value:

start = local_counts.sum(axis=0) + 1   # one pseudocount per category, so no entry is exactly zero
start = start / start.sum()            # renormalize to a probability vector

Now the sampling seems to be running nicely.
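The effect of the + 1 (add-one smoothing, i.e. one pseudocount per category) can be checked on a small, made-up count matrix. The toy local_counts below is purely illustrative: any category with zero total count gets an exact zero in the normalized start, while the smoothed start is strictly positive and still sums to 1.

```python
import numpy as np

# Hypothetical toy counts: 3 observations over 5 categories; category 3 is never observed.
local_counts = np.array([
    [4, 0, 1, 0, 5],
    [2, 3, 0, 0, 5],
    [1, 1, 3, 0, 5],
])

raw = local_counts.sum(axis=0)
start_raw = raw / raw.sum()                  # exact zero for the unobserved category

smoothed = local_counts.sum(axis=0) + 1      # one pseudocount per category
start_smoothed = smoothed / smoothed.sum()   # strictly positive, still sums to 1

print(np.count_nonzero(start_raw == 0))
print(np.count_nonzero(start_smoothed == 0))
```

Since the pseudocount is only used for testval, it changes where the sampler starts, not the prior or the likelihood of the model.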

Thank you very much for the support! I will keep this helpful routine in mind for future debugging :wink: