Numerical Issues with StickBreaking in ADVI

Explanation

Using the Dirichlet distribution to model linear mixtures with ADVI results in seemingly biased decompositions. The following toy example mixes ten equivalent components. I expected the resulting partition of one (decomp) to be somewhat close to [.1]*10. Instead, almost all of the mass accumulates in the first entry, while the other entries are very close to zero. This apparent bias towards the edge of the parameter space grows the longer ADVI iterates and can even throw off models that otherwise produce robust results.

Example

import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
from pymc3.distributions.transforms import t_stick_breaking

np.random.seed(1)
sample = np.random.randint(0, 1e5, 10)
def mix(components, decomp):
    return tt.dot(decomp[None, :], tt.nnet.softmax(components))

with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10), shape=10,
                         transform=t_stick_breaking(1e-9))
    components = [pm.Normal(str(i), shape=sample.shape) for i in range(10)]
    components = tt.stack(components, axis=0)
    combined = pm.Deterministic('combined', mix(components, decomp))
    obs = pm.Multinomial('obs', np.sum(sample), combined, observed=sample)
    mean_field = pm.fit(method='advi', n=int(1e5), progressbar=False)
decomp = mean_field.bij.rmap(mean_field.mean.get_value())

print(theano.config.floatX)
print(t_stick_breaking(1e-9).backward(decomp['decomp_stickbreaking__']).eval())

Output

Finished [100%]: Average Loss = 168.91
float64
[9.88788496e-01 3.56906299e-03 1.81464568e-03 2.00007439e-03
 6.65347116e-04 6.24080247e-04 5.88694760e-04 5.96708122e-04
 6.79652121e-04 6.73238123e-04]

Expected Output

Something close to [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1].

Question

Is this really a numerical issue or a mistake on my side? And is there a way to fix the example model without changing the theoretical result?

Related Things

Versions

pymc3==3.7 (current master https://github.com/pymc-devs/pymc3/commit/e3b667c7515e5519f8afe711d6d5723c65ee0311)
Theano==1.0.4
numpy==1.17.0
Ubuntu 18.04.3 LTS


The problem also seems to be related to the number of multinomial trials. Using 10 instead of 1e5 in sample = np.random.randint(0, 1e5, 10) produces much less bias, but already a value of 100 seems critical.
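For illustration, a minimal sketch of that check: fit_decomp is a hypothetical helper that just wraps the example model from above and refits it at a few trial scales.

import numpy as np
import pymc3 as pm
import theano.tensor as tt
from pymc3.distributions.transforms import t_stick_breaking

def fit_decomp(scale, seed=1, n=int(1e5)):
    # refit the toy model with observations drawn on a given scale
    np.random.seed(seed)
    sample = np.random.randint(0, scale, 10)
    with pm.Model():
        decomp = pm.Dirichlet('decomp', np.ones(10), shape=10,
                              transform=t_stick_breaking(1e-9))
        components = tt.stack(
            [pm.Normal(str(i), shape=sample.shape) for i in range(10)], axis=0)
        combined = pm.Deterministic(
            'combined', tt.dot(decomp[None, :], tt.nnet.softmax(components)))
        pm.Multinomial('obs', np.sum(sample), combined, observed=sample)
        mean_field = pm.fit(method='advi', n=n, progressbar=False)
    mean = mean_field.bij.rmap(mean_field.mean.get_value())
    return t_stick_breaking(1e-9).backward(mean['decomp_stickbreaking__']).eval()

for scale in (10, 100, 1000, int(1e5)):
    print(scale, fit_decomp(scale).round(3))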

I replaced the StickBreaking with something that is less convoluted and should be numerically fine: https://github.com/pymc-devs/pymc3/pull/3620

However, the problem stays the same.

I think the problem may arise from the model itself. The theoretical mean of decomp is [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1], but the distribution is multimodal and hence cannot be approximated well with the default transformation used by ADVI (presumably StickBreaking). The multimodality comes from the fact that the extreme observation sample cannot be explained well by any of the components. In turn, it is more likely for one single component to assume a very unlikely state that approximates sample than for all of them to do so.

I could do one of these three things to resolve the issue:

  • use a different parameter distribution in the variational inference, e.g. the particle approximation in SVGD (see the sketch after this list).
  • find an alternative StickBreaking transformation that maps decomp to a unimodal distribution that can be approximated well with a normal distribution.
  • change the model to better explain the observation sample.
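For illustration, a minimal sketch of the first option, reusing the model from the example above; the particle count and the number of iterations are arbitrary choices.

with model:
    # SVGD approximates the posterior with a set of particles instead of a
    # factorized normal, so it can in principle capture multimodality
    approx = pm.fit(method='svgd', n=int(1e4),
                    inf_kwargs=dict(n_particles=100), progressbar=False)
print(approx.sample(2000)['decomp'].mean(axis=0))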

After further diagnostics, it turns out the distribution of decomp is not multimodal. Using the NUTS sampler reveals distributions of decomp_stickbreaking__ that could be approximated nicely by ADVI with a normal distribution:

>>> trace = pm.sample(model=model, draws=5000)
>>> ax = pd.DataFrame(trace['decomp_stickbreaking__']).plot.kde()
>>> ax.set_xlim(-10, 10)

[image: KDE plot of the decomp_stickbreaking__ marginals from the NUTS trace]
The transformed decomp_stickbreaking__ seems biased, but this may be due to the StickBreaking transformation, as decomp itself is as desired:

>>> np.mean(trace['decomp'], axis=0)
array([0.10333791, 0.09897287, 0.10272751, 0.09945381, 0.09798678,
       0.09716101, 0.09705229, 0.09979954, 0.10609969, 0.09740859])

However, sampling with NUTS produced some divergences:

>>> trace['diverging'].nonzero()[0].size
106

The points of divergence seem to be at low values of the first coordinates of decomp:

import matplotlib.pyplot as plt

def pairplot_divergence(trace, var1, var2, i1=0, i2=0):
    # scatter plot of two (indexed) variables with divergent samples marked in red
    v1 = trace.get_values(varname=var1, combine=True)[:, i1]
    v2 = trace.get_values(varname=var2, combine=True)[:, i2]
    _, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(v1, v2, 'o', color='b', alpha=.5)
    divergent = trace['diverging']
    ax.plot(v1[divergent], v2[divergent], 'o', color='r')
    ax.set_xlabel('{}[{}]'.format(var1, i1))
    ax.set_ylabel('{}[{}]'.format(var2, i2))
    ax.set_title('scatter plot between {}[{}] and {}[{}]'.format(var1, i1, var2, i2))
    return ax

pairplot_divergence(trace, 'decomp', 'decomp', i1=0, i2=1)

[image: scatter plot of decomp[0] vs decomp[1]; the divergent samples (red) cluster near the origin]
In the image above, which looks at the first two coordinates of decomp, the divergences are close to the origin, as opposed to the image below, where we look at the last two coordinates of decomp.

pairplot_divergence(trace, 'decomp', 'decomp', i1=8, i2=9)

[image: scatter plot of decomp[8] vs decomp[9]; the divergent samples are not concentrated near the origin]
However, looking at the transformed values does not reveal such a tendency:

pairplot_divergence(trace, 'decomp_stickbreaking__', 'decomp_stickbreaking__', i1=0, i2=8)

[image: scatter plot of decomp_stickbreaking__[0] vs decomp_stickbreaking__[8]; the divergences show no clear clustering]
So I still suspect that the StickBreaking transformation itself introduces some numerical difficulties.

Increasing target_accept and adding more tuning, however, seems to remove this issue entirely:

>>> trace2 = pm.sample(model=model, draws=50000, tune=2000, target_accept=.99)
>>> trace2['diverging'].nonzero()[0].size
0

One could argue that extremely large gradients throw off the stochastic optimization of ADVI as well. The high sensitivity of the model to the choice of the ADVI optimizer further supports this theory. I was able to eliminate the divergences in the NUTS sampling by increasing the target acceptance rate and hence decreasing the step size. Equivalently, using a much more conservative stochastic optimizer seems to mitigate the problem:

>>> mean_field2 = pm.fit(model=model, method='advi', n=int(1e6), progressbar=False,
...                      obj_optimizer=pm.adamax(learning_rate=1e-4, beta1=1-1e-4, beta2=1-1e-5))
>>> decomp = mean_field2.bij.rmap(mean_field2.mean.get_value())
>>> print(StickBreaking().backward(decomp['decomp_stickbreaking__']).eval())
[0.06584135 0.06966171 0.07110197 0.07978009 0.08543297 0.08582187
 0.11319634 0.11781051 0.15662726 0.15472593]

There is still a notable bias in the result, so I would welcome any suggestions on how to improve it.

Thanks for the follow-up and the detailed write-up. I still need to have a closer look at your model and posts, but a quick note: did you try the sumto1 transformation? It is not the same as stick-breaking, but it might be worth trying as well.

I did not know there was such a transformation. Looking at the code (https://github.com/pymc-devs/pymc3/blob/05e95dc01eee194a8e67e4e7a987906cf21cc8b2/pymc3/distributions/transforms.py#L401-L425), it seems the transformed values have some constraints, e.g., being positive and summing to something below 1. Using this in ADVI with the usual normal distributions to model the parameter distribution would violate the support constraint: the support of the fitted parameter distribution must coincide with the support of the posterior.

Using this transformation, SumTo1, with ADVI immediately yields the error:

FloatingPointError: NaN occurred in optimization. 

That’s the same constraint as for stick-breaking. Basically, what a transformation does is transfer the variable from the reals (the parameter that is actually being optimized/sampled) to a constrained space (the parameter that actually gets plugged into the logp: logp(value)).
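For illustration, the backward of the default stick-breaking transform maps the zero vector of the unconstrained space to (approximately) the uniform point of the simplex:

import numpy as np
from pymc3.distributions.transforms import t_stick_breaking

# backward maps R^{n-1} into the simplex; zeros land near [0.1, ..., 0.1]
print(t_stick_breaking(1e-9).backward(np.zeros(9)).eval())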

Ok, but the backward method of SumTo1 does not seem to map the whole \mathbb R^{n-1} into the simplex. E.g.:

>>> SumTo1().backward_val(np.ones(2))
array([ 1.,  1., -1.])

And the support of logp() of the Dirichlet distribution is only the simplex, while ADVI models a parameter distribution on \mathbb R^{n-1}.

You are right, sum_to_1 only produces vectors that sum to one, not points in the simplex.

I am not sure if that really solved anything or just slowed down the divergence. Using it in my larger model does not seem to improve the results; it only takes much longer until the output starts to become weird.

I investigated some suspicions separately:

The Multinomial Distribution

Suspicious of the numerical accuracy of pm.Multinomial when used for the observations, I replaced the line

obs = pm.Multinomial('obs', np.sum(sample), combined, observed=sample)

with

mdist = pm.Dirichlet.dist(sample+1)
pot = pm.Potential('obs', mdist.logp(combined))

This implementation yields equivalent results in theory, since the Dirichlet distribution is conjugate to the Multinomial distribution. Contrary to expectations, the numerical results turn out to be equivalent as well, i.e., still biased:

[9.81476523e-01 1.10539880e-02 3.41269485e-03 6.53532756e-04
 5.42709558e-04 5.16784915e-04 5.39402198e-04 5.59363132e-04
 6.12063159e-04 6.32938071e-04]
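As a quick sanity check of the conjugacy claim (a minimal sketch with arbitrary example values): as functions of the probability vector, the Multinomial logp and the Dirichlet(sample + 1) logp should differ only by a constant.

import numpy as np
import pymc3 as pm

sample = np.array([5., 3., 2.])
# two probability vectors that sum to exactly 1 in floating point
p1 = np.array([0.5, 0.25, 0.25])
p2 = np.array([0.125, 0.375, 0.5])

multi_logp = lambda p: pm.Multinomial.dist(int(sample.sum()), p).logp(sample).eval()
diri_logp = lambda p: pm.Dirichlet.dist(sample + 1.).logp(p).eval()

# both lines should print the same p-independent constant
print(multi_logp(p1) - diri_logp(p1))
print(multi_logp(p2) - diri_logp(p2))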

Mixing of Softmax

tt.nnet.softmax maps into the simplex, and the weighted average realized through the dot product with decomp \sim Dirichlet should also be in the simplex, so mix should map into the simplex as well. Since the logp of combined is only defined on the simplex, samples that do not quite lie within it can easily produce divergences. To correct any numerical errors of the mix function, I forced the result to lie within the simplex more directly with:

def mix(components, decomp):
    result = tt.dot(decomp[None, :], tt.nnet.softmax(components))
    result = tt.switch(result>0, result, 0)
    result /= tt.sum(result)
    return result

but the results are still biased:

[9.92594050e-01 9.99273530e-04 1.03004292e-03 1.15697420e-03
 6.35424208e-04 6.73484137e-04 6.65853911e-04 6.97714797e-04
 7.70208461e-04 7.76973765e-04]

Edit: I fixed a mistake in mix and the bias came back. The last result was random, since result was effectively mapped to tt.ones(10).


At this point, it looks to me more like a problem with stick-breaking.
I knew I was missing a transformation that also produces a simplex :sweat_smile:

You may be right. I had to correct my last report and now I am back at the start… I attempted to use my own StickBreaking:

from pymc3.theanof import floatX  # provides the float casts used below

class ownStickBreaking(pm.distributions.transforms.transform):
    """
    Transforms the (K - 1)-dimensional simplex (K values in [0, 1] that sum to 1)
    to a vector of K - 1 real values.
    """

    name = "stickbreaking"

    def forward(self, x_):
        x = x_.T
        n = x.shape[0]
        lx = tt.log(x)
        shift = tt.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def forward_val(self, x_):
        x = x_.T
        n = x.shape[0]
        lx = np.log(x)
        shift = np.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def backward(self, y_):
        y = y_.T
        y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)])
        # "softmax" with vector support and no deprication warning:
        e_y = tt.exp(y - tt.max(y, 0, keepdims=True))
        x = e_y / tt.sum(e_y, 0, keepdims=True)
        return floatX(x.T)

    def backward_val(self, y_):
        y = y_.T
        y = np.concatenate([y, -np.sum(y, 0, keepdims=True)])
        e_y = np.exp(y - np.max(y, 0, keepdims=True))
        x = e_y / np.sum(e_y, 0, keepdims=True)
        return floatX(x.T)
    
    def jacobian_det(self, x_):
        x = x_.T
        n = x.shape[0]
        sx = tt.sum(x, 0, keepdims=True)
        r = tt.concatenate([x+sx, tt.zeros(sx.shape)])
        # stable according to: http://deeplearning.net/software/theano_versions/0.9.X/NEWS.html
        sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True))
        d = tt.log(n) + (n*sx) - (n*sr)
        return d.T

and use it in

decomp = pm.Dirichlet('decomp', np.ones(10), shape=10, transform=ownStickBreaking())

but the result is still heavily biased:

[3.82887573e-02 5.27427015e-02 3.37799631e-01 5.06014240e-02
 4.40718641e-02 4.55886969e-02 5.35025041e-02 3.76248055e-01
 2.89980647e-04 8.66385293e-04]

And sampling from it still produces divergences:

>>> trace = pm.sample(model=model)
Sampling 4 chains, 154 divergences: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4000/4000 [05:52<00:00,  2.54draws/s]
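As a quick sanity check (a sketch reusing ownStickBreaking from above), forward_val and backward_val should invert each other on a random simplex point:

import numpy as np

# round-trip through the custom transform should reproduce the simplex point
x = np.random.dirichlet(np.ones(10))
sb = ownStickBreaking()
print(np.allclose(sb.backward_val(sb.forward_val(x)), x))  # expect True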

I have no insight into why exactly VI gives such a horrible result. I suspect it has something to do with multimodality as well (for the components), or a local minimum.
For what it's worth, reducing the number of training epochs and changing to another optimizer might help:
[Edit]: this doesn't seem to be a good solution, see the post below.

import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
from pymc3.distributions.transforms import t_stick_breaking

np.random.seed(1)
nd = 10
sample = np.random.randint(0, 10000, nd)
def mix(components, decomp):
    return tt.dot(decomp, tt.nnet.softmax(
        tt.horizontal_stack(tt.zeros((nd, 1)), components)))
    
with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(10), shape=(1, 10),
                         transform=t_stick_breaking(1e-9))
    components = pm.Normal('components', shape=(nd, nd-1))
    combined = pm.Deterministic('combined', mix(components, decomp))
    obs = pm.Multinomial('obs', np.sum(sample), combined, observed=sample)
    mean_field = pm.fit(method='advi', n=int(1e4), obj_optimizer=pm.adam(),
                        progressbar=False)
decomp = mean_field.bij.rmap(mean_field.mean.get_value())

print(theano.config.floatX)
print(t_stick_breaking(1e-9).backward(decomp['decomp_stickbreaking__']).eval())

Note that I also did some refactoring (fixing the first softmax logit to a column of zeros) to make sure the softmax doesn't make the model unidentifiable.

Looking at it closer, I am pretty sure there is a multimodality problem:

pm.traceplot(trace, var_names=['decomp'], compact=False);

pm.pairplot(np.asarray(trace.get_values('decomp_stickbreaking__', combine=False)).squeeze());


Thank you for that insight.

What do you mean by the softmax making the model unidentifiable? I can reproduce the multimodal results only if I replace one component with tt.zeros((nd, 1)) as you suggested. With the original model I get:

pm.traceplot(trace, var_names=['decomp'], compact=False)

Softmax is something like softmax = lambda x: exp(x) / sum(exp(x)). Every time there is some kind of self-normalization, like dividing by the sum, it usually makes your model unidentifiable. You might not see it when plotting the transformed variable decomp, but it is likely visible in the raw space decomp_stickbreaking__.
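A minimal check of that point: the softmax is invariant to adding a constant to its input, so infinitely many components values map to the same mixture weights.

import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

x = np.random.randn(10)
print(np.allclose(softmax(x), softmax(x + 123.0)))  # True: shift-invariant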


I do not quite get it. lambda x: softmax(StickBreaking().backward(x)) is an isomorphism \mathbb R^{n-1}\rightarrow\Delta^{n-1}, and pm.Multinomial only takes p\in\Delta^{n-1}. So all latent variables should be determined by the observation. Additionally, the uncorrected model does not show multimodality:

pm.pairplot(np.asarray(trace.get_values('decomp_stickbreaking__', combine=False)).squeeze());


I am sorry for my foolish stubbornness, but I really need to understand this.

It’s not; it is just that in your case the Normal(0, 1) prior regularized it. If you set components = pm.Normal('components', mu=0., sigma=100., shape=(nd, nd)), you can see quite clearly the problem introduced by the softmax. Also, note that the softmax is applied to the components; there is no StickBreaking().backward(x) involved.
