We have a multinomial model with input classes of very different sizes, and while investigating it we discovered that the posterior variances of the effect variables do not depend on the number of samples from each class at all.
In trying to understand what is going on, I boiled it down to this very simple toy example:
import pymc as pm, scipy as sp, numpy as np
import scipy.stats  # so sp.stats is available on older scipy versions

vary_variances, shared_offset = False, False

N0, N1 = 1000, 10                          # samples per input class
p0, p1 = [0.2, 0.3, 0.5], [0.5, 0.3, 0.2]  # true output probabilities
inp = np.concatenate([np.zeros(N0), np.ones(N1)]).astype('int')
outp = np.concatenate([sp.stats.multinomial.rvs(1, p0, size=N0),
                       sp.stats.multinomial.rvs(1, p1, size=N1)])

with pm.Model() as model:
    # one row of logits per input class, zero-sum over the 3 output classes
    w = pm.ZeroSumNormal('w', n_zerosum_axes=1, shape=(2, 3))
    s = np.ones(3)
    if vary_variances:
        s = s * pm.LogNormal('s', 0, 1, shape=(3,))  # per-output-class scale
    ws = pm.Deterministic('ws', w * s[None, :])
    ws = ws - ws.mean(axis=1)[:, None]  # re-center rows; no effect on the softmax
    pm.Deterministic('ws2', ws)
    if shared_offset:
        # fixed effect shared by both input classes
        ws = ws + pm.ZeroSumNormal('ws_offset', n_zerosum_axes=1, shape=(3,))[None, :]
    p = pm.math.softmax(ws, axis=1)
    pm.Multinomial('res', n=outp.sum(axis=1), p=p[inp], observed=outp)
    idata = pm.sample()

idata.posterior.ws.std(['chain', 'draw'])
(It is a multinomial softmax model with two input and three output categories, with N0 = 1000 samples from the first input category and only N1 = 10 from the second.)
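To spell out in symbols what the code does: for sample $j$ with input class $x_j \in \{0, 1\}$,

$$y_j \sim \mathrm{Multinomial}\!\left(1,\ \mathrm{softmax}(\tilde w_{x_j})\right), \qquad \tilde w_i = \left(w_i \odot s - \overline{w_i \odot s}\right) + b,$$

where the rows $w_i \in \mathbb{R}^3$ are ZeroSumNormal, $s$ is either all ones or LogNormal (when vary_variances), $b$ is either zero or a shared ZeroSumNormal offset (when shared_offset), and the bar denotes the row mean that the code subtracts.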
With both vary_variances and shared_offset set to False, everything works as expected:
<xarray.DataArray 'ws' (ws_dim_0: 2, ws_dim_1: 3)>
array([[0.05339899, 0.04879904, 0.04217254],
       [0.42457227, 0.47800792, 0.55826867]])
i.e. the stdev of the estimates for the first input category is around 10 times smaller, following the square-root law as expected.
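Explicitly, with posterior stdev scaling as $1/\sqrt{N}$, the expected ratio is

$$\frac{\sigma_{\text{class }1}}{\sigma_{\text{class }0}} \approx \sqrt{\frac{N_0}{N_1}} = \sqrt{\frac{1000}{10}} = 10,$$

which matches the roughly 0.05 vs. 0.5 stdevs above.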
Now, enabling vary_variances messes this up (somewhat unsurprisingly, since the column scales s break the per-row zero-sum constraint on w), but that is easily fixed by just subtracting out the row means, which has no effect on the softmax result.
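A quick standalone check of that invariance (plain numpy/scipy, independent of the model):

```python
# softmax is invariant to adding a constant within each row, so subtracting
# the row means cannot change the implied probabilities.
import numpy as np
from scipy.special import softmax

x = np.random.default_rng(0).normal(size=(2, 3))  # arbitrary logits
centered = x - x.mean(axis=1, keepdims=True)      # subtract row means
print(np.allclose(softmax(x, axis=1), softmax(centered, axis=1)))  # True
```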
My problem is what happens with both vary_variances=True and shared_offset=True; the latter just adds a fixed effect that is the same for every input category. In that case, even the stdevs of ws2 (which is computed before the offset is added) come out almost equal across the two input categories:
<xarray.DataArray 'ws2' (ws2_dim_0: 2, ws2_dim_1: 3)>
array([[0.46666511, 0.44031758, 0.42308026],
       [0.52039793, 0.52217466, 0.4740543 ]])
This seems like a simple and standard enough model that I'm hoping there is a standard solution I'm just not aware of. Or am I missing something?