Recovering variance from multinomial softmax models

We have a multinomial model with classes of very different sizes, and while investigating it we discovered that the variances of the effect variables do not depend on the number of samples per class at all.

In trying to understand what is going on, I boiled it down to this very simple toy example:

import pymc as pm, scipy as sp, numpy as np

vary_variances, shared_offset = False, False

# N0 samples from input category 0, N1 from input category 1
N0, N1 = 1000, 10
p0, p1 = [0.2, 0.3, 0.5], [0.5, 0.3, 0.2]

inp = np.concatenate([np.zeros(N0), np.ones(N1)]).astype('int')
outp = np.concatenate([sp.stats.multinomial.rvs(1, p0, size=N0),
                       sp.stats.multinomial.rvs(1, p1, size=N1)])

with pm.Model() as model:
    # per-input-category effects, zero-sum across the 3 output categories
    w = pm.ZeroSumNormal('w', n_zerosum_axes=1, shape=(2, 3))
    s = np.ones(3)

    # optional per-output-category scales
    if vary_variances:
        s = s * pm.LogNormal('s', 0, 1, shape=3)

    ws = pm.Deterministic('ws', w * s[None, :])
    # re-center each row (no effect on the softmax output)
    ws = ws - ws.mean(axis=1)[:, None]
    pm.Deterministic('ws2', ws)

    # optional shared offset, identical for both input categories
    if shared_offset:
        ws = ws + pm.ZeroSumNormal('ws_offset', n_zerosum_axes=1, shape=3)[None, :]

    p = pm.math.softmax(ws, axis=1)

    pm.Multinomial('res', n=outp.sum(axis=1), p=p[inp], observed=outp)

    idata = pm.sample()

# posterior spread of the (uncentered) effects
idata.posterior.ws.std(['chain', 'draw'])

(it is a multinomial softmax model with two input categories and three output categories, with N0 = 1000 samples from the first input category and only N1 = 10 from the second)

With both vary_variances and shared_offset set to False, everything works as expected:

<xarray.DataArray 'ws' (ws_dim_0: 2, ws_dim_1: 3)>

array([[0.05339899, 0.04879904, 0.04217254],
       [0.42457227, 0.47800792, 0.55826867]])

i.e. the standard deviations of the estimates for the first input category are around 10 times smaller, which follows the square-root law as expected (sqrt(N0/N1) = sqrt(100) = 10).

Now, enabling vary_variances messes this up (somewhat unsurprisingly), but that is relatively easy to fix by just subtracting out the row means (which has no effect on the softmax result).
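To make that invariance concrete, here is a quick standalone NumPy/SciPy check (separate from the model) that row-centering leaves the softmax output unchanged:

import numpy as np
from scipy.special import softmax

x = np.random.default_rng(0).normal(size=(2, 3))
x_centered = x - x.mean(axis=1, keepdims=True)
# softmax is invariant to adding a per-row constant, so these agree
assert np.allclose(softmax(x, axis=1), softmax(x_centered, axis=1))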

My problem is what happens with both vary_variances=True and shared_offset=True, the latter of which just adds a fixed offset irrespective of the input category. In that case, even the centered ws2 comes out with almost equal standard deviations:

<xarray.DataArray 'ws2' (ws2_dim_0: 2, ws2_dim_1: 3)>

array([[0.46666511, 0.44031758, 0.42308026],
       [0.52039793, 0.52217466, 0.4740543 ]])

This seems like a simple and standard enough model that I’m hoping there is a standard solution I’m just not aware of? Or am I just missing something?

Hard to reason about all the permutations, but it sounds like you over-parametrize when you do both. What do the pair plots of the unconstrained parameters look like?

You can keep them by passing idata_kwargs={"include_transformed": True} to pm.sample (IIRC)
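Something like this (a sketch, reusing the model from the toy example above):

import arviz as az

with model:
    idata = pm.sample(idata_kwargs={"include_transformed": True})

# the unconstrained variables (e.g. w_zerosum__) are now kept in the trace
az.plot_pair(idata, var_names=['w_zerosum__'], kind='kde')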

Not quite sure what I am looking at with this picture, but the diagonal 3 steps down does look very different from the rest, so it is probably meaningful.

I would plot only the unconstrained parameters; you’re seeing the correlation between s and log(s), which is not meaningful.

And compare that with the well behaving model

I still don’t quite get what you are looking for, I think. Because

az.plot_pair(idata, kind='kde', var_names=['w_zerosum__'])

comes out pretty similar other than the scales.

Well-behaving (False, False)

vs

Seemingly overdispersed (True, True)

Also - when you say it is hard to reason about all the permutations, what do you mean?
The model itself is a rather straightforward multinomial GLM, just one variant with an intercept and the other without.

It seems the (False, True) case shows something rather interesting, however:

w[0] and ws_offset are almost 100% correlated. Which is not very surprising, as what we are observing 99% of the time is w[0] + ws_offset (with the other 1% being w[1] + ws_offset), so yes, there are two extra degrees of freedom there.
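This can also be checked numerically; a sketch using the default dimension names PyMC assigns in this toy model:

import xarray as xr

# align the output-category dims so the correlation is computed per category
w0 = idata.posterior['w'].isel(w_dim_0=0).rename({'w_dim_1': 'cat'})
off = idata.posterior['ws_offset'].rename({'ws_offset_dim_0': 'cat'})

w0 = w0.stack(sample=('chain', 'draw'))
off = off.stack(sample=('chain', 'draw'))
print(xr.corr(w0, off, dim='sample'))  # values near +/-1 confirm the redundancy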

But - again, this looks to me like the simplest, most standard multinomial GLM you can make. So there must be a standard way of avoiding this, right?

Does the model behave with shared_offset but not vary_variances?

Nope, with either one set to True, the variances for the groups come out a lot more similar than they should be.

What if you have more than 2 groups? Sounds like with just 2, the intercept and the scales can find a combination that fits the data without requiring much from the individual weights.

@ricardoV94 I tried with 3 groups and the situation did not look much different. And as I had other, more urgent things to do, I had to put this aside for a week or so.

Probably good that I did. Coming back fresh to it, I think I can now make sense of it all.

The key thing is that what we should expect to be roughly the same across all these models is the variance of the centered ws just as it goes into the softmax (since softmax does not care if a constant is added to all the values).

This much was obvious, hence my attempt at centering between ws and ws2. But I forgot that w and ws_offset can come out strongly correlated with each other (which is what the pair plots show, and I guess what @ricardoV94 was trying to point me to), and that also changes the final variances, because Var(X+Y) = Var(X) + Var(Y) + 2 Cov(X, Y).

So to answer my own question:

Because the intercept can covary with the predictor effects, if you care about variances you should consider them together, i.e. look at the variance of (intercept + effect), as that factors in their covariance.
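In code, for the (True, True) toy model above, that means something like this (a sketch reusing the posterior variable names, with ws2 as the centered effect and ws_offset as the intercept):

# fold the intercept back in before taking stds, aligning its dim
# with the output-category dim of ws2
off = idata.posterior['ws_offset'].rename({'ws_offset_dim_0': 'ws2_dim_1'})
(idata.posterior['ws2'] + off).std(['chain', 'draw'])

This should separate the per-row standard deviations again, since the covariance between the effects and the offset is now accounted for.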

I’m guessing the same logic should work for multiple independent variables, i.e. if the model for the (unnormalized) log-probability of result category c is y_kc = i_c + v_kc + w_kc (where v_kc and w_kc are the effects of the first and second independent variables on the k-th respondent for result category c), then meaningful variances can be read off (i_c + v_kc) and (i_c + w_kc).
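As a sketch with hypothetical variable names (i for the intercept, v and w for the two effects; none of these exist in the toy model above):

# hypothetical names for the bigger model: i (intercept), v, w (effects)
std_v = (idata.posterior['i'] + idata.posterior['v']).std(['chain', 'draw'])
std_w = (idata.posterior['i'] + idata.posterior['w']).std(['chain', 'draw'])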

When I have some time, I’ll try to implement this logic for my bigger models and see if it gives me what I’m expecting.
