Gumbel-Softmax version of Bernoulli and Categorical distributions

Does pymc have a Gumbel-Softmax reparametrization of the Bernoulli and Categorical distributions?
Does that reparametrization actually improve the fitting of Bernoulli and Categorical models?
Is it possible to use a random variable for the temperature?

Thanks again.


I’m not familiar with these parametrizations – do you have a reference?

In general, a different parametrization may change how well a particular model fits, but exactly how it changes depends on the details. symbolic-pymc is an interesting project that aims to do things like this automatically, and Hoffman, Johnson, and Tran had a paper on this last year.

The 2016 paper is at: https://arxiv.org/abs/1611.01144

tfp has an ExpRelaxedOneHotCategorical, but the temperature must be a float. The "Exp" part of the name refers to a further reparametrization (working with the log of the relaxed sample) that avoids underflow in the logp.
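
To see why working in log space helps, here is a small pure-Python sketch (not TFP's code) of the idea: at a very low temperature a numerically stable softmax still underflows small components to exactly 0.0, while the log-softmax of the same logits stays finite. The fixed noise values `g` are hypothetical stand-ins for Gumbel(0, 1) draws so the output is deterministic.

```python
import math

def softmax(z):
    # Stable softmax (max subtracted); tiny components can still underflow to 0.0.
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

def log_softmax(z):
    # Stable log-softmax: z_i - logsumexp(z), never exponentiates the small entries alone.
    m = max(z)
    lse = m + math.log(sum(math.exp(zi - m) for zi in z))
    return [zi - lse for zi in z]

# Hypothetical fixed inputs so the example is deterministic.
log_pi = [math.log(0.2), math.log(0.3), math.log(0.5)]
g = [0.1, -0.5, 1.2]   # stands in for Gumbel(0, 1) noise
tau = 0.001            # very low temperature

z = [(lp + gi) / tau for lp, gi in zip(log_pi, g)]
y = softmax(z)          # relaxed sample on the simplex
log_y = log_softmax(z)  # the same sample, kept in log space

print(y)      # [0.0, 0.0, 1.0]: two components underflowed
print(log_y)  # all finite (large negative numbers for the first two)
```

Any logp term that divides by y_i (or takes log y_i) breaks on the first representation but not on the second.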

As the temperature \tau approaches 0, the Gumbel-Softmax distribution approaches a (one-hot) categorical distribution.
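
A minimal sampler, assuming the usual construction y = softmax((log \pi + g) / \tau) with g_i i.i.d. Gumbel(0, 1), illustrates both points: as \tau shrinks the largest entry of y approaches 1 (nearly one-hot), and at any \tau the argmax of y is distributed exactly as Categorical(\pi) (the Gumbel-max trick). The helper names are hypothetical.

```python
import math
import random

def gumbel_noise(rng):
    # Gumbel(0, 1) by inverse CDF: g = -log(-log U), with U kept strictly inside (0, 1).
    u = rng.uniform(1e-12, 1.0 - 1e-12)
    return -math.log(-math.log(u))

def gumbel_softmax_sample(pi, tau, rng):
    # Relaxed one-hot sample: y = softmax((log pi + g) / tau).
    z = [(math.log(p) + gumbel_noise(rng)) / tau for p in pi]
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

rng = random.Random(0)
pi = [0.2, 0.3, 0.5]

def mean_max_component(tau, n=2000):
    # Average size of the largest entry of y; it approaches 1 as tau -> 0.
    return sum(max(gumbel_softmax_sample(pi, tau, rng)) for _ in range(n)) / n

cold = mean_max_component(0.05)
hot = mean_max_component(10.0)
print(cold, hot)  # cold is close to 1; hot is much smaller (y is closer to uniform)

# The argmax of y follows Categorical(pi) at any tau.
n = 20000
counts = [0, 0, 0]
for _ in range(n):
    y = gumbel_softmax_sample(pi, 1.0, rng)
    counts[y.index(max(y))] += 1
freqs = [c / n for c in counts]
print(freqs)  # approximately [0.2, 0.3, 0.5]
```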


I don't see why their probability density is well-defined, because some y_{i} can be 0. Then I get division by zero in the \pi_{i} / y_{i}^{\tau+1} terms.


For a categorical variable Y with k = 3 classes, the (3 - 1) = 2-dimensional simplex is an ordinary triangle.

Let \vec{y} \in \Delta^{k-1} be the relaxed one-hot encoding of a categorical variable with k classes, where \Delta^{k-1} is the (k - 1)-dimensional probability simplex.
Let \pi_{i} = P(Y = i) be the probability that the categorical random variable has the value i.

P_{\pi, \tau}(\vec{y}) = \Gamma(k) \tau^{k-1} \left( \sum_{i=1}^{k} \pi_{i} / y_{i}^{\tau} \right)^{-k} \prod_{i=1}^{k} (\pi_{i} / y_{i}^{\tau+1})
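
A sketch of this density in log space, using the paper's form with the sum raised to the power -k. It assumes every y_i > 0, i.e. \vec{y} lies in the interior of the simplex, which is where Gumbel-Softmax samples land in exact arithmetic because softmax outputs are strictly positive. It also assumes the density is taken over the first k - 1 coordinates (y_k = 1 - y_1 - ... - y_{k-1}); the midpoint-rule check for k = 2 is consistent with that. The function names are hypothetical.

```python
import math

def gumbel_softmax_logpdf(y, pi, tau):
    # log of Gamma(k) tau^(k-1) (sum_i pi_i / y_i^tau)^(-k) prod_i pi_i / y_i^(tau+1).
    # Assumes sum(y) == 1 and every y_i > 0 (interior of the simplex).
    k = len(y)
    if any(yi <= 0.0 for yi in y):
        raise ValueError("density is only defined for y_i > 0")
    log_terms = [math.log(p) - tau * math.log(yi) for p, yi in zip(pi, y)]
    m = max(log_terms)
    log_sum = m + math.log(sum(math.exp(t - m) for t in log_terms))  # log sum_i pi_i / y_i^tau
    log_prod = sum(math.log(p) - (tau + 1) * math.log(yi) for p, yi in zip(pi, y))
    return math.lgamma(k) + (k - 1) * math.log(tau) - k * log_sum + log_prod

def total_mass(pi, tau, n=20000):
    # k = 2 check: integrate over y_1 in (0, 1) with the midpoint rule, y_2 = 1 - y_1.
    h = 1.0 / n
    return h * sum(
        math.exp(gumbel_softmax_logpdf([(j + 0.5) * h, 1.0 - (j + 0.5) * h], pi, tau))
        for j in range(n)
    )

print(total_mass([0.3, 0.7], 1.0))  # close to 1
print(total_mass([0.3, 0.7], 2.0))  # close to 1
```

For k = 2, \tau = 1 and \pi = (0.5, 0.5) the density is uniform on (0, 1), which makes a handy spot check.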

Here \Gamma(k) is the gamma function, which equals (k - 1)! when k is a positive integer.
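
A quick check with Python's standard library that \Gamma(k) = (k - 1)! for positive integers k, so the normalizing constant for k = 3 is 2, not 3! = 6:

```python
import math

# Gamma(k) equals (k - 1)! for positive integers k.
for k in range(1, 8):
    assert math.isclose(math.gamma(k), math.factorial(k - 1))

print(math.gamma(3), math.factorial(3))  # 2.0 6
```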


torch has a gumbel_softmax too

https://pytorch.org/docs/stable/_modules/torch/nn/functional.html

They have an epsilon. Maybe they clamp the y_i to avoid division by zero?

Why does the paper say a categorical distribution is non-differentiable?
I was thinking P(\vec{y} \mid \vec{\pi}) = \sum_{i = 1}^{k} y_{i} \pi_{i}, which is differentiable in \vec{\pi}.
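
One reading of the paper's claim: the pmf \sum_{i} y_{i} \pi_{i} is differentiable in \pi, but a *sample* \vec{y}, viewed as a function of \pi for fixed noise, is piecewise constant, so its gradient is zero almost everywhere and undefined at the jumps. The Gumbel-Softmax sample is a smooth function of \pi, so gradients can flow through it. The sketch below compares finite differences of both; the fixed noise `g` and the helper names are hypothetical.

```python
import math

# Hypothetical fixed Gumbel noise so the comparison is deterministic.
g = [0.3, -0.2, 0.1]

def hard_sample(pi):
    # Gumbel-max: one-hot of argmax(log pi_i + g_i); piecewise constant in pi.
    z = [math.log(p) + gi for p, gi in zip(pi, g)]
    i = z.index(max(z))
    return [1.0 if j == i else 0.0 for j in range(len(pi))]

def soft_sample(pi, tau=0.5):
    # Gumbel-Softmax: softmax((log pi_i + g_i) / tau); smooth in pi.
    z = [(math.log(p) + gi) / tau for p, gi in zip(pi, g)]
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

pi = [0.2, 0.3, 0.5]
eps = 1e-6
pi_shift = [pi[0] + eps, pi[1], pi[2] - eps]  # move a little mass from class 3 to class 1

# Finite-difference derivatives of the first component of each sample.
d_hard = (hard_sample(pi_shift)[0] - hard_sample(pi)[0]) / eps
d_soft = (soft_sample(pi_shift)[0] - soft_sample(pi)[0]) / eps
print(d_hard, d_soft)  # 0.0 for the hard sample, a positive number for the relaxed one
```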