Numerical Issues with StickBreaking in ADVI

I investigated some suspicions separately:

The Multinomial Distribution

Suspicious of the numerical accuracy of pm.Multinomial when used for the observations, I replaced the line

obs = pm.Multinomial('obs', np.sum(sample), combined, observed=sample)

with

mdist = pm.Dirichlet.dist(sample + 1)           # conjugate density, counts as concentration parameters
pot = pm.Potential('obs', mdist.logp(combined))  # add its logp at combined to the model

In theory, this implementation yields equivalent results, since the Dirichlet distribution is conjugate to the Multinomial: as a function of combined, the Multinomial likelihood of the counts is proportional to the density of Dirichlet(sample + 1). Contrary to my suspicion, the numerical results are equivalent as well:

[9.81476523e-01 1.10539880e-02 3.41269485e-03 6.53532756e-04
 5.42709558e-04 5.16784915e-04 5.39402198e-04 5.59363132e-04
 6.12063159e-04 6.32938071e-04]
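
As a sanity check on the conjugacy argument (my own sketch, not part of the original experiment), the two log-densities can be compared directly; they should differ only by a constant that does not depend on the simplex point. The counts vector here is made up:

import numpy as np
import pymc3 as pm

counts = np.array([8, 3, 2, 1, 1])  # hypothetical counts, not the data from the model above
for p in np.random.dirichlet(np.ones(5), size=3):  # a few random points on the simplex
    lp_multi = pm.Multinomial.dist(n=counts.sum(), p=p, shape=5).logp(counts).eval()
    lp_dir = pm.Dirichlet.dist(counts + 1.0, shape=5).logp(p).eval()
    print(lp_multi - lp_dir)  # prints the same constant for every p

The constant is log(n!) - log Γ(n + k) for n total counts over k categories, so it cancels when comparing logp values across different values of combined.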

Mixing of Softmax

tt.nnet.softmax maps into the simplex, and the weighted average realized through the dot product with decomp ~ Dirichlet should also lie in the simplex, so mix should map into the simplex as well. However, since the logp of combined is only defined on the simplex, samples that do not quite lie within it due to floating-point error can easily produce divergences. To correct any numerical errors of the mix function, I forced the result to lie within the simplex more directly:

import theano.tensor as tt

def mix(components, decomp):
    # weighted average of the softmaxed components; shape (1, n_categories)
    result = tt.dot(decomp[None, :], tt.nnet.softmax(components))
    # clip negative rounding errors to zero and renormalize onto the simplex
    result = tt.switch(result > 0, result, 0)
    result /= tt.sum(result)
    return result

but the results are still biased:

[9.92594050e-01 9.99273530e-04 1.03004292e-03 1.15697420e-03
 6.35424208e-04 6.73484137e-04 6.65853911e-04 6.97714797e-04
 7.70208461e-04 7.76973765e-04]

Edit: I fixed a mistake in mix, and the bias came back. The last result was spurious, since result had effectively been mapped to tt.ones(10).
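
A different way to sidestep the problem entirely (a sketch of my own, not something tried above; mix_logspace is a hypothetical name) is to average in log space, so the result stays normalized by construction and only needs a single renormalization at the end:

import pymc3 as pm
import theano.tensor as tt

def mix_logspace(components, decomp):
    # row-wise log-softmax of the components; shape (n_components, n_categories)
    log_comps = tt.nnet.logsoftmax(components)
    # log of the weighted average via log-sum-exp over the components axis
    log_mix = pm.math.logsumexp(tt.log(decomp)[:, None] + log_comps, axis=0)
    # exponentiate and renormalize once to absorb residual rounding error
    result = tt.exp(log_mix)
    return result / tt.sum(result)

Since logsumexp subtracts the running maximum before exponentiating, the intermediate values stay well scaled even when individual component probabilities underflow in linear space.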
