Handling multiple cross-cutting (non-nested) groups in hierarchical model


#1

I have a dyadic model in which I am using random intercepts for the sender and receiver in each dyad. As you might expect, these intercepts are not identified: you could always add c to all sender effects and subtract it from all receiver effects without changing the fit. This makes the model particularly difficult to sample and raises concerns about convergence overall.
If I were fitting a fixed effects model, I would just exclude the dummy for one of the categories, i.e. set the effect for that category to 0. I am not sure how to set one value to 0 in PyMC3, or whether there’s another option that would work well. Does anyone have any ideas?
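To make the non-identifiability concrete, here is a small NumPy sketch on a hypothetical fully crossed layout of 3 senders and 4 receivers (sizes and values are illustrative, not taken from the model below). Shifting every sender effect by c and every receiver effect by -c leaves each dyad's mean unchanged, and the fixed-effects equivalent shows up as a dummy matrix that is rank deficient by one until a reference category is dropped.

```python
import numpy as np

# Hypothetical small dyadic layout: every sender paired with every receiver
sn, rn = 3, 4
si = np.repeat(np.arange(sn), rn)  # sender index of each dyad
ri = np.tile(np.arange(rn), sn)    # receiver index of each dyad

rng = np.random.default_rng(0)
s = rng.normal(5, 5, sn)   # sender intercepts
r = rng.normal(10, 7, rn)  # receiver intercepts

# Adding c to all senders and subtracting it from all receivers
# leaves every dyad's mean unchanged: the data cannot tell them apart.
c = 3.0
mean_original = s[si] + r[ri]
mean_shifted = (s + c)[si] + (r - c)[ri]
print(np.allclose(mean_original, mean_shifted))

# Fixed-effects analogue: both dummy sets sum to the constant column,
# so the full dummy matrix loses one rank.
S = np.eye(sn)[si]                  # n x sn sender dummies
R = np.eye(rn)[ri]                  # n x rn receiver dummies
full = np.hstack([S, R])
dropped = np.hstack([S, R[:, 1:]])  # drop the first receiver as reference
print(full.shape[1], np.linalg.matrix_rank(full))
print(dropped.shape[1], np.linalg.matrix_rank(dropped))
```

With every receiver dummy included the matrix has sn + rn columns but rank sn + rn - 1; dropping one receiver dummy gives full column rank, which is the "set one category to 0" fix.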

Here is a toy model with data that exemplifies the problem.

Simulating the data:

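import numpy as np
import pandas as pd

sn = 10          # number of senders
rn = 30          # number of receivers
n = sn*rn        # fully crossed: every sender paired with every receiver
senders = np.random.normal(5,5,sn)      # true sender intercepts
receivers = np.random.normal(10,7,rn)   # true receiver intercepts
sx = np.random.binomial(1,.3,sn)        # sender-level covariate
rx = np.random.binomial(1,.5,rn)        # receiver-level covariate
x = np.random.normal(20,10,n)           # dyad-level covariate
ri = list(range(rn))*sn                 # receiver index of each dyad
si = []                                 # sender index of each dyad
for i in range(sn):
    si.extend([i]*rn)
x = pd.DataFrame({'x':x,'ri':ri,'si':si})
x['er']=np.random.normal(0,2,n)         # observation noise
y = []
for i,row in x.iterrows():
    # iterrows returns floats, hence int() for indexing
    y.append(row.x*5 +row.er+
             sx[int(row.si)]*7 + senders[int(row.si)]+
             rx[int(row.ri)]*(-2) + receivers[int(row.ri)])
x['y']=y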

The model:

import pymc3 as pm

with pm.Model() as test1:
    # sender effects: intercept + covariate + non-centered random part
    mu_s = pm.Flat('sender_fe')
    sig_s = pm.HalfCauchy('sig_s', beta = 2.5)
    b_sx = pm.Flat('sx')
    raw_s = pm.Normal('raw_s',mu=0,sd=1,shape = sn)
    s = pm.Deterministic('s',
                         mu_s +
                         (b_sx*sx) +
                         (sig_s * raw_s)
                        )

    # receiver effects: same structure. mu_s and mu_r only enter the
    # likelihood through their sum, so they are not separately identified.
    mu_r = pm.Flat('receiver_fe')
    sig_r = pm.HalfCauchy('sig_r', beta = 2.5)
    b_rx = pm.Flat('rx')
    raw_r = pm.Normal('raw_r',mu=0,sd=1,shape = rn)
    r = pm.Deterministic('r',
                         mu_r +
                         (b_rx*rx) +
                         (sig_r * raw_r)
                        )

    b_x = pm.Normal('b_x',mu=0,sd=100**2)
    # linear predictor: look up each dyad's sender and receiver effect
    hat = pm.Deterministic('hat',
                           (b_x*x.x.values) +
                           s[x.si.values] +
                           r[x.ri.values]
                          )
    sig = pm.HalfCauchy('sig',beta=2.5)
    y = pm.Normal('y',mu=hat,sd=sig, observed = x.y.values)

#2

When your model is over-parameterized like this, reducing the number of parameters is a good and straightforward fix. You can try setting one of the mu_* to zero; for example, just use mu_r = 0.

Also, you should avoid using pm.Flat as a prior; a Normal distribution would be much better.


#3

Thanks for the reply.
Can you explain why for the flat prior? I’m trying to use the asymmetric sampling trick, since hierarchical models are so hard to sample. When I run a simulation without cross-cutting effects, the flat priors don’t cause a problem because of the term where I multiply raw*sigma. I recover my parameter values, and NUTS is less likely to get stuck and never finish when I run it this way.
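The raw*sigma construction described here is commonly called a non-centered parameterization. As a quick illustration in plain NumPy (mu = 5 and sigma = 2 are illustrative values), mu + sigma*raw with raw ~ Normal(0, 1) has the same distribution as drawing directly from Normal(mu, sigma). The sampler benefits because it explores unit-scale raw values rather than effects whose scale is tied to sigma, which is what typically produces funnel-shaped posteriors in hierarchical models.

```python
import numpy as np

rng = np.random.default_rng(42)
mu, sigma, draws = 5.0, 2.0, 200_000  # illustrative values

# Centered: draw the group effect directly from Normal(mu, sigma)
centered = rng.normal(mu, sigma, draws)

# Non-centered: draw a standard-normal "raw" value, then shift and scale it
raw = rng.normal(0.0, 1.0, draws)
non_centered = mu + sigma * raw

for name, sample in [("centered", centered), ("non-centered", non_centered)]:
    print(f"{name:>12}: mean={sample.mean():.2f}, sd={sample.std():.2f}")
```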

Now I feel really silly. Setting mu_r=0 works: I recover my coefficients, the model runs much faster, and in retrospect it is very obvious. The mean of the receiver effects simply gets absorbed into sender_fe, but I don’t care about those values, and there aren’t enough degrees of freedom to distinguish individual sender and receiver effects if they’re unknown anyway. I was making it way more complicated in my head.

Thank you!

I’ll post the working simulation later today for posterity. It is obvious in retrospect, but maybe it’ll help someone else who got stuck in a similar logic loop.


#4

I’ll try to explain one of the reasons by comparing it to what many machine learning methods call regularization. In ML there are many situations with a LOT of parameters. To prevent overfitting, and to help find the relevant parameters and automatically discard the irrelevant ones, a regularization term is added to the training loss function. Common examples are L1 and L2 regularization, which penalize large parameter values. This lets the fitting procedure keep irrelevant parameters at or near zero (exactly zero is possible with L1) and move them away from zero only if they truly matter for the fit.
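A minimal sketch of how the two penalties differ, assuming an orthonormal design (X^T X = I) so that each coefficient is penalized independently; the OLS estimates and penalty strength below are made up for illustration. Under that assumption L2 shrinks every coefficient by the same factor, while L1 soft-thresholds and sets small coefficients exactly to zero.

```python
import numpy as np

# Hypothetical least-squares estimates; assumes an orthonormal design
# (X^T X = I) so each coefficient can be penalized on its own.
b_ols = np.array([3.0, 0.4, -0.2, -2.5])
lam = 0.5  # penalty strength (illustrative)

# L2 (ridge), penalty lam/2 * b**2: every coefficient shrinks by the
# same factor 1/(1 + lam), but none becomes exactly zero.
b_l2 = b_ols / (1.0 + lam)

# L1 (lasso), penalty lam * |b|: soft-thresholding; coefficients whose
# magnitude is below lam are set exactly to zero.
b_l1 = np.sign(b_ols) * np.maximum(np.abs(b_ols) - lam, 0.0)

print("OLS:", b_ols)
print("L2 :", np.round(b_l2, 3))
print("L1 :", b_l1)
```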

These penalties are equivalent to using a Laplace (L1) or Normal (L2) prior instead of a flat prior: the penalized optimum is the posterior mode (MAP estimate) under that prior, so the prior does the regularizing automatically. The main difference between regularization and non-flat priors, in my opinion, is that in Bayesian inference you usually don’t look for the single most likely parameters (MAP) but sample from the full posterior distribution, so the results differ from standard ML.
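The L2/Normal correspondence can be checked numerically. In this sketch (simulated data, illustrative sigma and tau), ridge regression with penalty lam = sigma^2 / tau^2 is solved in closed form, and the gradient of the negative log posterior for y ~ Normal(Xb, sigma) with prior b ~ Normal(0, tau) is evaluated at that solution. The gradient vanishes, so the ridge estimate is the MAP estimate under the Normal prior.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 50, 3
X = rng.normal(size=(n, p))
b_true = np.array([1.5, 0.0, -2.0])
sigma, tau = 1.0, 0.5  # noise sd and prior sd (illustrative)
y = X @ b_true + rng.normal(0, sigma, n)

# Ridge: minimize ||y - Xb||^2 + lam * ||b||^2 with lam = sigma^2 / tau^2
lam = sigma**2 / tau**2
b_ridge = np.linalg.solve(X.T @ X + lam * np.eye(p), X.T @ y)

def neg_log_post_grad(b):
    """Gradient of -log p(b | y) for y ~ N(Xb, sigma^2), b ~ N(0, tau^2), constants dropped."""
    return -X.T @ (y - X @ b) / sigma**2 + b / tau**2

print(b_ridge)
print(np.abs(neg_log_post_grad(b_ridge)).max())  # zero up to floating-point error
```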


#5

This makes a lot of sense. It sounds like it will also help convergence in a highly parameterized model rather than making it more complex, which is what I’d been afraid of.
This also suggests that if there are nearly collinear factors, the prior should help keep them close to zero rather than splitting off to inf and -inf, which is part of the issue I was having.
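A small simulated example of the nearly collinear case (the data and lam are illustrative). With a flat prior the posterior mode is the OLS estimate, and because the data barely constrain b1 - b2, the two coefficients typically split into large values of opposite sign. A Normal prior, equivalent to a ridge penalty, pulls them back toward an even split near the true values of 1.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 100
x1 = rng.normal(size=n)
x2 = x1 + rng.normal(0, 1e-3, n)     # nearly collinear with x1
X = np.column_stack([x1, x2])
y = x1 + x2 + rng.normal(0, 1.0, n)  # true coefficients are [1, 1]

# Flat prior ~ ordinary least squares: b1 - b2 is barely constrained
b_ols = np.linalg.lstsq(X, y, rcond=None)[0]

# Normal(0, tau) prior ~ ridge penalty with lam = sigma^2 / tau^2 (illustrative)
lam = 10.0
b_ridge = np.linalg.solve(X.T @ X + lam * np.eye(2), X.T @ y)

print("OLS  :", np.round(b_ols, 2))
print("ridge:", np.round(b_ridge, 2))
```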
Thanks!


#6

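Following @junpenglao’s suggestion, here’s the corrected model. With mu_s fixed at 0, the sender and receiver effects are no longer absolute: the overall level is carried by receiver_fe, and the effects are only meaningful relative to each other. I also replaced the Flat priors with wide Normals.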

import pymc3 as pm

with pm.Model() as test1:
    # sender intercept fixed at 0 so the model is identified;
    # receiver_fe now carries the overall level
    mu_s = 0
    sig_s = pm.HalfCauchy('sig_s', beta = 2.5)
    b_sx = pm.Normal('sx',mu=0,sd=100)
    raw_s = pm.Normal('raw_s',mu=0,sd=1,shape = sn)
    s = pm.Deterministic('s',
                         mu_s +
                         (b_sx*sx) +
                         (sig_s * raw_s)
                        )

    mu_r = pm.Normal('receiver_fe',mu=0,sd=1000)
    sig_r = pm.HalfCauchy('sig_r', beta = 2.5)
    b_rx = pm.Normal('rx',mu=0,sd=100)
    raw_r = pm.Normal('raw_r',mu=0,sd=1,shape = rn)
    r = pm.Deterministic('r',
                         mu_r +
                         (b_rx*rx) +
                         (sig_r * raw_r)
                        )

    b_x = pm.Normal('b_x',mu=0,sd=100**2)
    # linear predictor: look up each dyad's sender and receiver effect
    hat = pm.Deterministic('hat',
                           (b_x*x.x.values) +
                           s[x.si.values] +
                           r[x.ri.values]
                          )
    sig = pm.HalfCauchy('sig',beta=2.5)
    y = pm.Normal('y',mu=hat,sd=sig, observed = x.y.values)
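For comparison with the fixed-effects approach mentioned in the question, here is a NumPy sketch that regenerates data like the simulation above (vectorized, with its own seed) and fits ordinary least squares with a dummy for every sender and for every receiver except the first, which serves as the reference. As in the corrected model, only relative effects are identified: the sender coefficients absorb the reference receiver’s effect, so the check compares differences between senders. Sender-level covariates such as sx cannot be estimated alongside a full set of sender dummies, since they are collinear with them.

```python
import numpy as np

rng = np.random.default_rng(123)
sn, rn = 10, 30
n = sn * rn
senders = rng.normal(5, 5, sn)
receivers = rng.normal(10, 7, rn)
sx = rng.binomial(1, .3, sn)
rx = rng.binomial(1, .5, rn)
si = np.repeat(np.arange(sn), rn)  # same dyad ordering as the simulation above
ri = np.tile(np.arange(rn), sn)
x = rng.normal(20, 10, n)
y = (5 * x + rng.normal(0, 2, n)
     + 7 * sx[si] + senders[si]
     - 2 * rx[ri] + receivers[ri])

# Design: x, one dummy per sender, receiver dummies minus the first (reference)
S = np.eye(sn)[si]
R = np.eye(rn)[ri][:, 1:]
X = np.column_stack([x, S, R])
coef = np.linalg.lstsq(X, y, rcond=None)[0]
b_x = coef[0]
sender_hat = coef[1:1 + sn]  # each = sender effect + reference receiver effect

# Only differences between senders are identified; compare those to the truth
true_sender = senders + 7 * sx
contrast_err = np.abs((sender_hat - sender_hat[0]) - (true_sender - true_sender[0]))
print("b_x:", round(b_x, 3))
print("max sender-contrast error:", contrast_err.max())
```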