Explanation
Using the Dirichlet distribution to model linear mixtures with ADVI results in seemingly biased decompositions. The following toy example mixes ten equivalent components. I expected the resulting partition of one (decomp) to be somewhat close to [.1]*10. Instead, it accumulates almost all of its mass in the first entry, while the other entries are very close to zero. What seems to be a bias towards the edge of the parameter space becomes stronger the longer ADVI runs, and it can even throw off models that otherwise produce robust results.
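To make the expectation concrete: the prior used below is a symmetric Dirichlet with all concentration parameters equal to one, so its mean is 1/10 in every entry. The following is only a minimal sketch of that prior, not part of the model; it uses the standard library instead of PyMC3, and the helper name sample_dirichlet is made up for illustration.

```python
import random

def sample_dirichlet(alpha, rng):
    """Draw one Dirichlet(alpha) sample by normalising independent Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]

rng = random.Random(1)
alpha = [1.0] * 10          # same concentration as np.ones(10) in the example
n_draws = 20000

samples = [sample_dirichlet(alpha, rng) for _ in range(n_draws)]

# Monte Carlo estimate of the prior mean of each entry
prior_mean = [sum(s[i] for s in samples) / n_draws for i in range(10)]
print([round(m, 3) for m in prior_mean])
```

This only describes the prior. Since the ten components are interchangeable, nothing in the model itself singles out the first entry, which is why the result looks suspicious to me.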
Example
import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
from pymc3.distributions.transforms import t_stick_breaking

np.random.seed(1)
# Observed counts for ten categories
sample = np.random.randint(0, 1e5, 10)

def mix(components, decomp):
    # Weight the softmax-normalised components (one per row) by decomp
    return tt.dot(decomp[None, :], tt.nnet.softmax(components))

with pm.Model() as model:
    # Mixture weights with a flat Dirichlet prior
    decomp = pm.Dirichlet('decomp', np.ones(10), shape=10,
                          transform=t_stick_breaking(1e-9))
    # Ten equivalent components with standard normal priors
    components = [pm.Normal(str(i), shape=sample.shape) for i in range(10)]
    components = tt.stack(components, axis=0)
    combined = pm.Deterministic('combined', mix(components, decomp))
    obs = pm.Multinomial('obs', np.sum(sample), combined, observed=sample)
    mean_field = pm.fit(method='advi', n=int(1e5), progressbar=False)

# Map the flat mean vector of the approximation back to named variables
decomp = mean_field.bij.rmap(mean_field.mean.get_value())
print(theano.config.floatX)
# Transform the unconstrained mean back onto the simplex
print(t_stick_breaking(1e-9).backward(decomp['decomp_stickbreaking__']).eval())
Output
Finished [100%]: Average Loss = 168.91
float64
[9.88788496e-01 3.56906299e-03 1.81464568e-03 2.00007439e-03
 6.65347116e-04 6.24080247e-04 5.88694760e-04 5.96708122e-04
 6.79652121e-04 6.73238123e-04]
Expected Output
Something close to [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1].
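For reference, this is how I understand the back-transformation from the unconstrained ADVI mean to the simplex. It is a plain-Python sketch of a stick-breaking map with an equal-share offset, written under the assumption that PyMC3's StickBreaking.backward behaves similarly (the eps handling is left out). The function names are mine and not PyMC3 API.

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def invlogit(x):
    return 1.0 / (1.0 + math.exp(-x))

def stick_breaking_backward(y):
    """Map K-1 unconstrained reals to a point on the K-simplex."""
    k_minus_1 = len(y)
    x = []
    remaining = 1.0  # part of the unit stick not yet handed out
    for k, y_k in enumerate(y):
        # Offset chosen so that y = 0 gives every entry an equal share
        eq_share = logit(1.0 / (k_minus_1 + 1 - k))
        z_k = invlogit(y_k + eq_share)   # fraction of the remaining stick
        x.append(remaining * z_k)
        remaining *= 1.0 - z_k
    x.append(remaining)                  # last entry takes what is left
    return x

# All-zero unconstrained values map to the uniform partition
uniform = stick_breaking_backward([0.0] * 9)

# A large first coordinate moves almost all mass into the first entry
skewed = stick_breaking_backward([6.0] + [0.0] * 8)

print([round(v, 4) for v in uniform])
print([round(v, 4) for v in skewed])
```

Under this reading, the expected output corresponds to an unconstrained mean near zero, while the observed output corresponds to a first unconstrained coordinate that has drifted far into the positive range.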
Question
Is this really a numerical issue or a mistake on my side? And is there a way to fix the example model without changing the theoretical result?
Related Things
Tutorial on AEVB for latent Dirichlet allocation with ADVI: https://docs.pymc.io/notebooks/lda-advi-aevb.html
eps in math.invlogit: https://github.com/pymc-devs/pymc3/issues/3001
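To illustrate why eps in the inverse logit might matter here, the sketch below shows how a small eps keeps the output away from exactly 0 and 1. The formula is my assumption about how such an eps is typically applied, written in plain Python; it is not copied from PyMC3, and invlogit_eps is a made-up name.

```python
import math

def invlogit_eps(x, eps=0.0):
    """Inverse logit squeezed into the interval [eps, 1 - eps]."""
    return (1.0 - 2.0 * eps) / (1.0 + math.exp(-x)) + eps

# Without eps the result saturates to exactly 1.0 in float64
saturated = invlogit_eps(40.0)

# With eps the result stays strictly inside the open unit interval
low = invlogit_eps(-40.0, eps=1e-9)
high = invlogit_eps(40.0, eps=1e-9)

print(saturated == 1.0, low, high)
```

If the stick fractions saturate like this, logarithms of them or of their complements become infinite, which is the kind of numerical problem I have in mind when asking whether this is a numerical issue.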
Versions
pymc3==3.7 (current master https://github.com/pymc-devs/pymc3/commit/e3b667c7515e5519f8afe711d6d5723c65ee0311)
Theano==1.0.4
numpy==1.17.0
Ubuntu 18.04.3 LTS