When I look at the source code of pymc3, I see that the logp of Dirichlet evaluates to -np.inf if the random variable is out of bounds of the probability simplex. But -np.inf is not differentiable. How does Dirichlet deal with the out-of-bounds situation? How would the integrator and the mass matrix adaptation work with a logp that is -np.inf?
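For the -np.inf part, here is my (possibly incomplete) understanding of a generic Metropolis/HMC accept step. A proposal whose logp is -np.inf has infinite energy and is accepted with probability 0, and the accept/reject decision itself only needs the logp value, not its gradient. `accept_prob` and its arguments are hypothetical names, not pymc3 API, and this does not show what pymc3's NUTS does internally.

```python
import numpy as np

def accept_prob(logp_current, logp_proposal, kinetic_current=0.0, kinetic_proposal=0.0):
    """Metropolis acceptance probability min(1, exp(H_current - H_proposal)).

    H = kinetic - logp, so a proposal with logp = -inf has H = +inf and
    exp(H_current - inf) = 0: the proposal is always rejected.
    """
    h_current = kinetic_current - logp_current
    h_proposal = kinetic_proposal - logp_proposal
    # ignore overflow warnings when the proposal is much better than the current point
    with np.errstate(over="ignore"):
        return float(min(1.0, np.exp(h_current - h_proposal)))

print(accept_prob(-1.0, float("-inf")))  # out-of-support proposal
print(accept_prob(-1.0, -0.5))           # better proposal, always accepted
```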
The Dirichlet also uses a “stick breaking” transformation. Does that mean the Dirichlet is reparametrized as independent Beta distributions, each bounded in the interval (0, 1)?
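On the Beta question: there is a well-known generative stick-breaking construction in which independent draws v_k ~ Beta(a_k, a_{k+1} + ... + a_K) are combined into a single Dirichlet(a) draw. Below is a NumPy sketch of that construction (`dirichlet_via_betas` is a hypothetical helper). As far as I can tell, pymc3's `transforms.stick_breaking` is a deterministic change of variables rather than a sampler of Betas, so this only illustrates the geometry of breaking a unit stick.

```python
import numpy as np

def dirichlet_via_betas(a, n_samples, rng):
    """Draw Dirichlet(a) samples by breaking a unit stick with independent Betas."""
    a = np.asarray(a, dtype=float)
    K = a.shape[0]
    x = np.empty((n_samples, K))
    remaining = np.ones(n_samples)  # length of the stick not yet broken off
    for k in range(K - 1):
        # second Beta parameter: total concentration to the right of component k
        v = rng.beta(a[k], a[k + 1:].sum(), size=n_samples)
        x[:, k] = v * remaining
        remaining = remaining * (1.0 - v)
    x[:, K - 1] = remaining  # last component gets whatever stick is left
    return x

rng = np.random.default_rng(0)
a = np.array([2.0, 3.0, 5.0])
samples = dirichlet_via_betas(a, 100_000, rng)
print(samples.mean(axis=0))  # should be close to a / a.sum() = [0.2, 0.3, 0.5]
```

Every row sums to 1 by construction, which is the property the simplex constraint asks for.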
While the bound condition in the source code seems to constrain each value to the interval [0, 1], the comment states that the logp is only defined when the values sum to 1. Would an additional reparametrization to logits transform the (0, 1) constraint into a (-∞, ∞) constraint?
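On the logit question: one bijection from R^(K-1) onto the open simplex is the logit-based stick-breaking map described in the Stan reference manual, sketched below in NumPy. The function names are hypothetical, and I am not claiming this is line-for-line what pymc3's `transforms.stick_breaking` computes. A sampler working in the unconstrained coordinates y would also need the log-absolute-Jacobian of the map added to the logp, which is omitted here.

```python
import numpy as np

def _sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def simplex_from_unconstrained(y):
    """Map y in R^(K-1) onto the open K-simplex via logit stick-breaking."""
    y = np.asarray(y, dtype=float)
    K = y.shape[0] + 1
    x = np.empty(K)
    remaining = 1.0
    for k in range(K - 1):
        # the log(1 / (K - 1 - k)) offset sends y = 0 to the uniform point (1/K, ..., 1/K)
        z = _sigmoid(y[k] + np.log(1.0 / (K - 1 - k)))  # fraction of the remaining stick
        x[k] = remaining * z
        remaining -= x[k]
    x[K - 1] = remaining
    return x

def unconstrained_from_simplex(x):
    """Inverse map: a point strictly inside the simplex back to R^(K-1)."""
    x = np.asarray(x, dtype=float)
    K = x.shape[0]
    y = np.empty(K - 1)
    remaining = 1.0
    for k in range(K - 1):
        z = x[k] / remaining
        y[k] = np.log(z / (1.0 - z)) - np.log(1.0 / (K - 1 - k))  # logit minus offset
        remaining -= x[k]
    return y

x = simplex_from_unconstrained(np.array([-3.0, 0.5, 4.0]))
print(x, x.sum())
```

Any finite y lands strictly inside the simplex, so in these coordinates a sampler cannot propose an out-of-bounds value.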
Thanks again.
Cut & paste from the pymc3 source code:
class Continuous(Distribution):
    def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
                 *args, **kwargs):
        if dtype is None:
            dtype = theano.config.floatX
        super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
....
class Dirichlet(Continuous):
    ....
    def __init__(self, a, transform=transforms.stick_breaking,
                 *args, **kwargs):
        shape = np.atleast_1d(a.shape)[-1]
        kwargs.setdefault("shape", shape)
        super().__init__(transform=transform, *args, **kwargs)
        ...
    def logp(self, value):
        k = self.k
        a = self.a
        # only defined for sum(value) == 1
        return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
                     + gammaln(tt.sum(a, axis=-1)),
                     tt.all(value >= 0), tt.all(value <= 1),
                     k > 1, tt.all(a > 0), broadcast_conditions=False)

def bound(logp, *conditions, **kwargs):
    broadcast_conditions = kwargs.get('broadcast_conditions', True)
    if broadcast_conditions:
        alltrue = alltrue_elemwise
    else:
        alltrue = alltrue_scalar
    return tt.switch(alltrue(conditions), logp, -np.inf)
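To check my reading of the excerpt, here is a NumPy re-implementation of the quoted `Dirichlet.logp` combined with `bound`. It is a sketch under assumptions: `dirichlet_logp_np` is a hypothetical name, it handles a single 1-D value without Theano, and treating 0 · log 0 as 0 is my stand-in for pymc3's `logpow`, not a copy of it. One difference from the original: `tt.switch` builds both branches, whereas this sketch returns early when a condition fails.

```python
import math
import numpy as np

def dirichlet_logp_np(value, a):
    """NumPy mirror of the quoted Dirichlet.logp for one 1-D value."""
    value = np.asarray(value, dtype=float)
    a = np.asarray(a, dtype=float)
    # the same four conditions the quoted logp passes to bound();
    # note that sum(value) == 1 is NOT among them
    conditions = (np.all(value >= 0), np.all(value <= 1),
                  a.shape[0] > 1, np.all(a > 0))
    if not all(conditions):
        return -np.inf  # the branch tt.switch(..., logp, -np.inf) would select
    with np.errstate(divide="ignore", invalid="ignore"):
        # (a - 1) * log(value), with 0 * log(0) treated as 0
        terms = np.where(a == 1, 0.0, (a - 1) * np.log(value))
    return (float(np.sum(terms))
            - sum(math.lgamma(ai) for ai in a)
            + math.lgamma(float(a.sum())))

print(dirichlet_logp_np([0.2, 0.3, 0.5], [1.0, 1.0, 1.0]))   # log(2), flat density
print(dirichlet_logp_np([-0.1, 0.6, 0.5], [1.0, 1.0, 1.0]))  # -inf, out of bounds
print(dirichlet_logp_np([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))   # finite, although the sum is 1.5
```

The last call returns a finite value, which matches the comment in the excerpt: the sum-to-one condition is not checked by `bound` and, as I understand it, only holds because values are produced through the stick-breaking transform.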