How does pymc3 implement the Dirichlet distribution?

When I look at the source code of pymc3, I see that the logp of Dirichlet is -np.inf if the random variable falls outside the probability simplex. But -np.inf is not differentiable. How does Dirichlet deal with the out-of-bounds situation?

How would the integrator and mass matrix adaptation work with a logp that is -np.inf?

The Dirichlet also uses a “stick breaking” transformation. Does that mean reparametrizing the Dirichlet as independent Beta distributions that are bounded in the interval (0, 1)?

While the bound conditions in the source code seem to constrain the values to the interval [0, 1], the comment states that the bound constrains the sum of the values to be 1. Would an additional reparametrization to logits transform the (0, 1) constraint into a (-\infty, \infty) constraint?
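For reference, a minimal sketch of the kind of logit reparametrization I have in mind (assuming pymc3's default logodds transform for a Beta variable): the sampler works on the whole real line rather than on (0, 1).

import pymc3 as pm

with pm.Model() as m:
    # A (0, 1)-bounded variable; pymc3 automatically samples its logit.
    x = pm.Beta('x', alpha=1.0, beta=1.0)

# The free (sampled) variable lives on the real line.
print(m.free_RVs)  # [x_logodds__]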

Thanks again.


Cut & paste from pymc3 source code:


class Continuous(Distribution):
    def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
                 *args, **kwargs):
        if dtype is None:
            dtype = theano.config.floatX
        super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
        ....


class Dirichlet(Continuous):
    ....
    def __init__(self, a, transform=transforms.stick_breaking,
                 *args, **kwargs):
        shape = np.atleast_1d(a.shape)[-1]

        kwargs.setdefault("shape", shape)
        super().__init__(transform=transform, *args, **kwargs)
        ...
    def logp(self, value):
        k = self.k
        a = self.a

        # only defined for sum(value) == 1
        return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
                     + gammaln(tt.sum(a, axis=-1)),
                     tt.all(value >= 0), tt.all(value <= 1),
                     k > 1, tt.all(a > 0), broadcast_conditions=False)

def bound(logp, *conditions, **kwargs):
    broadcast_conditions = kwargs.get('broadcast_conditions', True)
    if broadcast_conditions:
        alltrue = alltrue_elemwise
    else:
        alltrue = alltrue_scalar
    return tt.switch(alltrue(conditions), logp, -np.inf)

Uniform sampling over the probability simplex with Beta distributions:

[Figure 1: 3-D scatter plot of the samples, which lie on the probability simplex]

import numpy as np

k = 3

# Stick-breaking: z[i] ~ Beta(k - i, 1) is the fraction of the remaining
# stick that is kept at step i, so 1 - z[i] is the fraction broken off.
alpha = np.arange(1, k)[::-1]
beta = np.ones(shape=k - 1)
samples = []
for i in range(1000):
    z = np.random.beta(alpha, beta)
    a = np.hstack([1, z])
    b = np.hstack([1 - z, 1])
    cp = np.cumprod(a)  # stick remaining before each break
    x = cp * b          # the broken-off pieces, which sum to 1
    samples.append(x)

samples = np.vstack(samples)
print(samples)
print(np.sum(samples, axis=1))

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2])
plt.show()

Take a look at the transform itself:

This actually parametrizes the k-simplex as a point in \mathbb{R}^{k+1} [edit: typo, that should be \mathbb{R}^{k-1}].

In the case where parameters are Dirichlet-distributed, this parametrization means there cannot be any out-of-bounds problems.

On the other hand, if you have data that is Dirichlet-distributed (i.e. passed with observed=) and you pass in a point not on the simplex, you will get a likelihood of -np.inf.
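As a rough illustration (a sketch of my own, assuming pymc3 3.x; the Dirichlet parameters here are arbitrary): a free Dirichlet variable is sampled through its unconstrained stick-breaking counterpart, while an off-simplex value fed directly to logp hits the bound and returns -inf.

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    p = pm.Dirichlet('p', a=np.ones(3))

# The sampler works on the unconstrained stick-breaking variable,
# which has one fewer dimension than p.
print(model.free_RVs)                               # [p_stickbreaking__]
print(model.test_point['p_stickbreaking__'].shape)  # (2,)

# A value that violates the bound conditions (negative component)
# gets a log-probability of -inf.
bad = np.array([-0.1, 0.6, 0.5])
print(pm.Dirichlet.dist(a=np.ones(3)).logp(bad).eval())  # -inf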

And if you did something crazy like

x = pm.Normal('x', 0.5, 0.1, shape=8)
lik = pm.Potential('lik', pm.Dirichlet(...).logp(x))

the -np.inf likelihoods would cause points outside the simplex to be rejected automatically.
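A runnable variant of that snippet (my own sketch; the shape and the Dirichlet parameters are arbitrary choices) shows the joint log-probability dropping to -inf as soon as x leaves the bounds, which is why such proposals get rejected:

import numpy as np
import pymc3 as pm

with pm.Model() as m:
    x = pm.Normal('x', 0.5, 0.1, shape=3)
    # The Potential adds the Dirichlet logp of x to the joint log-probability.
    pm.Potential('lik', pm.Dirichlet.dist(a=np.ones(3)).logp(x))

print(m.logp({'x': np.array([0.2, 0.3, 0.5])}))   # finite
print(m.logp({'x': np.array([-0.2, 0.3, 0.9])}))  # -inf (outside the bounds)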


I think you meant that “This actually parametrizes the k-simplex as a point in \mathbb{R}^{k-1}”.


What do forward and backward mean?

The forward and backward methods of StickBreaking do not seem to be inverses of each other:
a forward transform followed by a backward transform does not give back the original value.

Test:

import numpy as np
from pymc3.distributions.transforms import StickBreaking

x = np.random.uniform(0, 1, (2, 4))
sb = StickBreaking()
x2 = sb.backward(sb.forward(x))

print(x)
print(x2.eval())

Result:

[[0.19679426 0.38085212 0.78022595 0.91964995]
 [0.37214554 0.91556628 0.09703552 0.76239552]]
[[0.08640717 0.16722213 0.34257665 0.40379405]
 [0.17332128 0.42641144 0.04519286 0.35507443]]

You should do x2 = sb.forward(sb.backward(x)). In pymc3, forward maps a point on the probability simplex to the unconstrained parameter space, and backward maps an unconstrained point back onto the simplex; your x is not a point on the simplex, so backward(forward(x)) is not the identity.

Thanks for answering my question. I didn’t realize that the order of composition matters. I think I passed an argument that is outside the domain of the function.

Test2:

import numpy as np
from pymc3.distributions.transforms import StickBreaking

probs = np.random.uniform(0, 9, (3,))
probs /= np.sum(probs)
sb = StickBreaking()

theta = sb.forward(probs).eval()
probs2 = sb.backward(theta).eval()

print("probs", probs)
print("probs2", probs2)

Result (equal as expected):

probs [0.26137719 0.47124916 0.26737366]
probs2 [0.26137719 0.47124916 0.26737366]

When I read the source again, I think backward transforms from the unconstrained space to the probability simplex.
pymc3 source:

class StickBreaking(Transform):
    """
    Transforms K dimensional simplex space (values in [0,1] and 
    sum to 1) to K - 1 vector of real values.
    Primarily borrowed from the STAN implementation.
    """

Test:

import numpy as np
from pymc3.distributions.transforms import StickBreaking

theta = np.random.uniform(-9, 9, (3,))
sb = StickBreaking()

p1 = sb.backward(theta).eval()  
p2 = sb.forward(theta).eval()

print(p1, np.sum(p1))
print(p2, np.sum(p2))

Result:

# backward maps to probability simplex
[9.90819598e-01 1.54958553e-04 7.92423177e-05 8.94620147e-03] 1.0
# forward doesn't map to probability simplex
[        nan -0.33782225] nan

I am sampling over a set of parameters \theta. The backward transform of the parameters gives a set of probabilities p, which I use to build the likelihood of the model \mathcal{L}(p).

When I don’t use a prior for \theta, p seems to have a higher density near the edges and corners of the probability simplex.

I want p to have prior distribution that is uniform over the probability simplex.

\vec{p} = backward(\vec{\theta}). The prior in terms of p is a uniform distribution over the probability simplex. I need some way to get the corresponding prior in terms of \theta.

\vec{\theta} = forward(\vec{p}).

\Pr(\vec{p})\, d\vec{p} = \Pr(backward(\vec{\theta}))\, d\vec{p} = \Pr(backward(\vec{\theta}))\, \det[J(\vec{\theta})]\, d\vec{\theta}, where J(\vec{\theta}) = \partial\, backward(\vec{\theta}) / \partial \vec{\theta}.

The StickBreaking class has a jacobian_det method that calculates \ln \lbrace \lvert\det[J(\vec{\theta})] \rvert \rbrace

Since \Pr(backward(\vec{\theta})) is the density of \vec{p} evaluated at \vec{p} = backward(\vec{\theta}), the density of \vec{\theta} picks up the Jacobian factor: \rho_{\theta}(\vec{\theta})\, d\vec{\theta} = \Pr(\vec{p})\, d\vec{p} = \Pr(backward(\vec{\theta}))\, \det[J(\vec{\theta})]\, d\vec{\theta}

Then, \ln[\rho_{\theta}(\vec{\theta})] = \ln[\rho_{p}(\vec{p})] + \ln \lbrace \det[J(\vec{\theta})] \rbrace

Therefore, \ln[\rho_{\theta}(\vec{\theta})] = \ln[\rho_{p}(backward(\vec{\theta}))] + \ln \lbrace \det[J(\vec{\theta})] \rbrace, where \rho(\cdot) means a probability density function.

Since the original prior distribution is uniform over the probability simplex, \rho_{p}(\vec{p}) = \rho_{p}(backward(\vec{\theta})) is a constant.


Except there is a complication. If \vec{\theta} and \vec{p} have the same dimension, I can write:

J = \begin{bmatrix} \frac{\partial p_{1}}{ \partial \theta_{1} } & \frac{\partial p_{1}}{ \partial \theta_{2} } \\ \frac{\partial p_{2}}{ \partial \theta_{1} } & \frac{\partial p_{2}}{ \partial \theta_{2} } \\ \end{bmatrix}

However, when \vec{p} has n elements, \vec{\theta} only has n - 1 elements. The Jacobian matrix would be non-square. I can’t get the determinant of a rectangular matrix.
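One way around this (a sketch of my own, not something from the pymc3 docs): since p_n = 1 - \sum_{i < n} p_i, the last component of \vec{p} is redundant, so I can differentiate only the first n - 1 components of backward(\vec{\theta}) and take the determinant of the resulting square (n - 1) \times (n - 1) Jacobian. If jacobian_det is indeed \ln \lbrace \lvert\det[J(\vec{\theta})] \rvert \rbrace, the two numbers printed below should agree:

import numpy as np
import theano
import theano.tensor as tt
import theano.tensor.nlinalg as nlinalg
from pymc3.distributions.transforms import StickBreaking

sb = StickBreaking()

theta = tt.vector('theta')
p = sb.backward(theta)

# p[-1] = 1 - sum(p[:-1]) is redundant, so differentiate only the first
# n - 1 components of p with respect to theta: a square Jacobian.
J = theano.gradient.jacobian(p[:-1], theta)
log_det = tt.log(tt.abs_(nlinalg.det(J)))

f_log_det = theano.function([theta], log_det)
f_jac_det = theano.function([theta], sb.jacobian_det(theta))

theta_val = np.random.randn(2).astype(theano.config.floatX)
print(f_log_det(theta_val))  # log|det| of the square Jacobian
print(f_jac_det(theta_val))  # what jacobian_det returns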

Thanks again.

By the way, does the jacobian_det method of the StickBreaking class implement \ln \lbrace \lvert\det[J(\vec{\theta})] \rvert \rbrace or the negative of that? It appears to be \ln \lbrace \lvert\det[J(\vec{\theta})] \rvert \rbrace itself.

I directly sample a distribution over \vec{\theta} with log-probability equal to \text{jacobian\_det}(\vec{\theta}). After I get the samples \vec{\theta}_{1}, \ldots, \vec{\theta}_{N}, I calculate \vec{p}_{n} = \text{backward}(\vec{\theta}_{n}). Plotting \vec{p}_1, \ldots, \vec{p}_{N} gives a uniform distribution over the probability simplex.

The prior in terms of \theta that transforms to a uniform prior over the probability simplex should be \ln[\rho_{\theta}(\vec{\theta})] = \text{constant} + \ln \lbrace \det[J(\vec{\theta})] \rbrace, if my derivation in the last post is correct. Then, the jacobian_det method implements \ln \lbrace \det[J(\vec{\theta})] \rbrace up to an additive constant.

Test:

import pymc3.distributions.transforms as pdt
import pymc3 as pm
import matplotlib.pyplot as plt
from scipy.stats import kde
import numpy as np


def plot(samples):
    x = samples[:, 0]
    y = samples[:, 1]

    nbins = 300
    k = kde.gaussian_kde([x, y])
    xi, yi = np.mgrid[x.min():x.max():nbins * 1j, y.min():y.max():nbins * 1j]
    zi = k(np.vstack([xi.flatten(), yi.flatten()]))

    # Make the plot
    plt.pcolormesh(xi, yi, zi.reshape(xi.shape))
    plt.show()



sb = pdt.StickBreaking()
unnormalized_log_prob = sb.jacobian_det

with pm.Model():
    theta_dist = pm.Uniform('theta', lower=-1e5, upper=1e5,
                            shape=(2,))

    # use a DensityDist
    pm.DensityDist('likelihood',
                   lambda theta: unnormalized_log_prob(theta),
                   observed={'theta': theta_dist})
    start = {'theta': np.zeros((2,))}

    ndraws = 3000  # number of draws from the distribution
    nburn = 1000   # number of "burn-in points" (which we'll discard)

    trace = pm.sample(ndraws,
                      tune=nburn, discard_tuned_samples=True)

    theta_list = []
    for tr in trace:
        theta_list.append(tr['theta'])

    theta = np.vstack(theta_list)
    np.save('theta', theta)
    probs = sb.backward(theta).eval()
    plot(probs)


The result looks like the simplex in \mathbb{R}^{3} projected onto the x-y plane.

[Figure: 2-D kernel density plot of probs]