Numerical Issues with StickBreaking in ADVI

You may be right. I had to correct my last report, and now I am back at the start… I attempted to use my own StickBreaking:

import numpy as np
import pymc3 as pm
import theano.tensor as tt
from pymc3.theanof import floatX


class ownStickBreaking(pm.distributions.transforms.Transform):
    """
    Transforms the (K - 1)-dimensional simplex (K values in [0, 1] that sum to 1)
    to a vector of K - 1 real values.
    """

    name = "stickbreaking"

    def forward(self, x_):
        # log(x) minus its mean over components, with the last coordinate dropped
        x = x_.T
        n = x.shape[0]
        lx = tt.log(x)
        shift = tt.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def forward_val(self, x_):
        # NumPy counterpart of forward, evaluated outside the Theano graph
        x = x_.T
        n = x.shape[0]
        lx = np.log(x)
        shift = np.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def backward(self, y_):
        # append the implied last coordinate, then map back to the simplex via a softmax
        y = y_.T
        y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)])
        # "softmax" with vector support and no deprication warning:
        e_y = tt.exp(y - tt.max(y, 0, keepdims=True))
        x = e_y / tt.sum(e_y, 0, keepdims=True)
        return floatX(x.T)

    def backward_val(self, y_):
        y = y_.T
        y = np.concatenate([y, -np.sum(y, 0, keepdims=True)])
        e_y = np.exp(y - np.max(y, 0, keepdims=True))
        x = e_y / np.sum(e_y, 0, keepdims=True)
        return floatX(x.T)
    
    def jacobian_det(self, x_):
        # PyMC3 calls this with the transformed (real-valued) variable; it should
        # return the log of the Jacobian determinant of backward()
        x = x_.T
        n = x.shape[0]
        sx = tt.sum(x, 0, keepdims=True)
        r = tt.concatenate([x+sx, tt.zeros(sx.shape)])
        # stable according to: http://deeplearning.net/software/theano_versions/0.9.X/NEWS.html
        sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True))
        d = tt.log(n) + (n*sx) - (n*sr)
        return d.T

and used it like this:

decomp = pm.Dirichlet('decomp', np.ones(10), shape=10, transform=ownStickBreaking())
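
As a quick sanity check, the forward/backward pair does round-trip (a minimal sketch; the Dirichlet draw below is arbitrary):

sb = ownStickBreaking()
x = np.random.dirichlet(np.ones(10))       # an arbitrary point on the 10-simplex
y = sb.forward_val(x)                      # 9 real values
print(np.allclose(x, sb.backward_val(y)))  # True up to floating-point error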

But the ADVI estimate of decomp is still heavily biased:

[3.82887573e-02 5.27427015e-02 3.37799631e-01 5.06014240e-02
 4.40718641e-02 4.55886969e-02 5.35025041e-02 3.76248055e-01
 2.89980647e-04 8.66385293e-04]
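
A minimal sketch of how such a mean vector could be obtained, assuming only the Dirichlet prior stands in for the full model (the actual model and data are not shown here):

with pm.Model() as toy_model:
    decomp = pm.Dirichlet('decomp', np.ones(10), shape=10,
                          transform=ownStickBreaking())
    approx = pm.fit(method='advi')

advi_trace = approx.sample(5000)
# under a flat Dirichlet(1, ..., 1) prior the true mean of every entry is 0.1
print(advi_trace['decomp'].mean(axis=0))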

And NUTS sampling from the full model still produces divergences:

>>> trace = pm.sample(model=model)
Sampling 4 chains, 154 divergences: 100%|██████████| 4000/4000 [05:52<00:00,  2.54draws/s]
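
A quick way to count the divergent transitions, and one knob that sometimes helps (a sketch; 'diverging' is the NUTS sampler statistic, and recent PyMC3 versions accept target_accept directly in pm.sample):

print(trace['diverging'].sum())  # total number of divergent transitions across chains
# a higher target acceptance rate forces smaller steps and can reduce divergences:
trace_ta = pm.sample(model=model, target_accept=0.95)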