Why doesn't pymc.Multinomial support logit_p parameterization?

pm.Categorical supports the logit_p parameterization, which improves numerical stability by avoiding an explicit softmax when computing the log-likelihood (I presume it uses a stable log-sum-exp internally instead). The aggregated version, pm.Multinomial, currently does not support a similar logit_p argument.

Shouldn’t pm.Multinomial also benefit from the same numerical stability advantages if implemented using logits internally—computing the likelihood directly using a numerically stable log-sum-exp rather than explicitly performing a softmax and passing normalized probabilities (p)?

Or am I missing something that fundamentally prevents this numerical stability benefit in the multinomial case?
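For concreteness, here is a minimal sketch of the workaround I'm describing; the commented-out logit_p line is the hypothetical API I'm asking about, not something that currently exists:

```python
import pymc as pm
from pytensor.tensor.special import softmax

with pm.Model() as model:
    logits = pm.Normal("logits", mu=0, sigma=1, shape=3)

    # What I do now: explicit softmax, then pass normalized probabilities
    y = pm.Multinomial("y", n=10, p=softmax(logits), observed=[3, 3, 4])

    # What I'm asking about (hypothetical, not currently supported):
    # y = pm.Multinomial("y", n=10, logit_p=logits, observed=[3, 3, 4])
```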

PyTensor probably already does the optimization you have in mind if you do the explicit softmax.
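If you want to check whether that rewrite actually fires, you can inspect the compiled graph; this is just a sketch, and the exact ops printed will depend on your PyTensor version and rewrite settings:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.vector("x")
naive = pt.log(softmax(x))  # written the "unstable" way: log of an explicit softmax

f = pytensor.function([x], naive)
pytensor.dprint(f)  # after rewrites, this should show a fused, stable log-softmax op
```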

For the Bernoulli case, using logit_p=x or doing p=inverse_logit(x) is exactly the same: both cases are optimized numerically even if the user doesn't know about logit_p; it's there just for convenience. We can add the same convenience for Multinomial, I think there's a GitHub issue for that.
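As a quick sanity check (a sketch, not a guarantee for every possible graph), the two Bernoulli parameterizations give matching log-probabilities even at extreme logits, where a naive sigmoid-then-log would underflow:

```python
import numpy as np
import pymc as pm

x = np.array([-30.0, 0.0, 30.0])  # extreme logits
value = np.array([1, 1, 1])

logp_logit = pm.logp(pm.Bernoulli.dist(logit_p=x), value).eval()
logp_p = pm.logp(pm.Bernoulli.dist(p=pm.math.invlogit(x)), value).eval()

print(logp_logit)  # roughly [-30, -0.693, ~0]
print(logp_p)      # should match, thanks to PyTensor's graph rewrites
```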


Oh that’s awesome! I was wondering why I had not seen any issues at all using the softmax and passing in p. I keep forgetting that this is all symbolic. I really need to sit down and dig into PyTensor and how it works. Someday™!

I find this terminology confusing because the arguments to pm.Categorical are not logit probs, they’re log probs. If you start with the probability simplex (a vector of non-negative values that sum to 1), then apply logit to it, then apply softmax, you don’t get the original probabilities back. If you start with the probability simplex, then apply log to it, then apply softmax, you get back to where you started. So what are being called logits by everyone everywhere are not really logits. I have no idea where the convention of calling them logits comes from, but it’s so strong that I misnamed Stan’s categorical_logit function—it should be categorical_log. I explain more fully in this blog post: https://statmodeling.stat.columbia.edu/2024/12/26/those-are-unnormalized-log-probabilities-not-logits-in-your-neural-networks-final-layer/
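A quick numeric illustration of the point (my own example, using SciPy’s softmax and logit for convenience, not taken from the blog post):

```python
import numpy as np
from scipy.special import softmax, logit

p = np.array([0.2, 0.3, 0.5])  # a point on the probability simplex

print(softmax(np.log(p)))  # [0.2 0.3 0.5]  -- log then softmax recovers p
print(softmax(logit(p)))   # approx [0.149 0.255 0.596] -- logit then softmax does not
```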

If you accept log probs as an argument, then you don’t need to do a softmax or log-sum-exp, but if you need to do error checking, the check is that the arguments have a log-sum-exp of 0. If you accept unnormalized log probs as input, the operation you actually need for the categorical and multinomial is log(softmax(x)) = x - log_sum_exp(x).
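For example, in plain NumPy/SciPy terms (just to illustrate the identity and why the naive form breaks down):

```python
import numpy as np
from scipy.special import logsumexp, log_softmax

x = np.array([1000.0, 1001.0, 1002.0])  # unnormalized log probs; exp(x) overflows

print(x - logsumexp(x))                     # stable log(softmax(x))
print(log_softmax(x))                       # same result, built in
print(np.log(np.exp(x) / np.exp(x).sum()))  # naive form overflows to nan
```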
