Divergence encountered when fitting RL model

I’m fitting RL models for a cognitive science study I am currently running. However, I have been consistently getting lots of divergences, starting at around 15% on the progress bar. The trace plots are almost flat and there are divergences everywhere. I also can’t find any bugs in my code, and I’m out of ideas. Below I provide a minimal example using simulated data that replicates the issue. I’d be forever grateful if someone could give me some insights.

Experiment Overview:
In my experiment, each subject is asked to perform 1 of 3 candidate actions (encoded as 0, 1, and 2) as soon as a stimulus appears on the screen. Each action may give a reward of 0, 1, or 2. Each subject plays two independent blocks of this learning game. Within each block, 2 different stimuli are shown sequentially in an interleaved fashion. Each block has 24 trials in total, i.e. 12 trials per stimulus. Importantly, each stimulus is uniquely paired with another experimental condition: split. My RL model instantiates the hypothesis that this “split” variable modifies the subject’s mental representation of rewards during learning.
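To make the task structure concrete, here is a small sketch of how one block of this design could be simulated (purely illustrative: the reward rule, the stimulus-to-split pairing, and the random policy below are made-up placeholders, not the generative model I actually used):

import numpy as np

rng = np.random.default_rng(0)
n_stim, n_trial_per_stim, n_action = 2, 12, 3
split_per_stim = {0: 0.1, 1: 0.3}  # placeholder stimulus -> split pairing

# interleave the two stimuli, 12 presentations each, and track a per-stimulus trial counter
stim_seq = rng.permutation(np.repeat(np.arange(n_stim), n_trial_per_stim))
counter = {s: 0 for s in range(n_stim)}
rows = []
for s in stim_seq:
    counter[s] += 1
    a = rng.integers(n_action)  # placeholder: a uniformly random policy
    r = rng.integers(3)         # placeholder: rewards drawn uniformly from {0, 1, 2}
    rows.append((a, r, split_per_stim[s], counter[s], s))
action_blk, reward_blk, split_blk, trial_blk, stim_blk = map(np.array, zip(*rows))  # one block's worth of data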

Example Data:
The example data consists of 5 NumPy arrays: action, reward, split, trial, stim. Each array has shape (nSubj=2, nBlock=2, (nTrial=12)*(nStim=2)=24). action encodes which action was taken, reward encodes the reward generated by that action, split encodes the split condition, trial encodes how many times the current stimulus has been shown so far (a per-stimulus repetition counter), and stim encodes which stimulus (coded as 0 or 1) was displayed.

action = np.array([[[2, 2, 0, 0, 1, 2, 2, 2, 0, 0, 1, 1, 2, 2, 1, 2, 2, 1, 2, 0, 2,
                     1, 1, 1],
                    [2, 2, 2, 0, 1, 2, 1, 0, 1, 0, 2, 2, 1, 0, 0, 0, 2, 0, 1, 1, 0,
                     0, 0, 0]],

                   [[1, 2, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 2, 1,
                     1, 0, 1],
                    [2, 0, 0, 0, 2, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 0,
                       2, 0, 1]]])

reward = np.array([[[1., 1., 2., 0., 2., 1., 1., 1., 0., 0., 0., 2., 1., 1., 0., 1.,
                  1., 0., 1., 2., 1., 2., 2., 2.],
                    [1., 1., 1., 2., 0., 1., 2., 0., 2., 2., 1., 1., 0., 2., 0., 2.,
                  1., 0., 2., 0., 0., 0., 0., 2.]],

                   [[0., 1., 0., 1., 0., 2., 1., 2., 2., 0., 2., 2., 2., 2., 2., 2.,
                     2., 2., 2., 1., 2., 2., 2., 2.],
                    [1., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 0., 2., 2.,
                       1., 2., 2., 2., 2., 1., 0., 2.]]])

split = np.array([[[0.1, 0.3, 0.1, 0.3, 0.3, 0.1, 0.1, 0.3, 0.3, 0.3, 0.1, 0.3,
                 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1, 0.3, 0.3, 0.3, 0.3],
                   [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.3, 0.3, 0.1, 0.3, 0.3,
                 0.1, 0.1, 0.3, 0.1, 0.3, 0.3, 0.3, 0.1, 0.3, 0.3, 0.3, 0.1]],

                  [[1., 1., 0.6, 0.6, 0.6, 1., 0.6, 1., 1., 0.6, 1., 1.,
                    1., 0.6, 0.6, 0.6, 1., 1., 1., 0.6, 0.6, 0.6, 1., 0.6],
                   [0.6, 1., 1., 1., 0.6, 0.6, 1., 1., 1., 0.6, 1., 0.6,
                      0.6, 1., 1., 0.6, 0.6, 1., 1., 0.6, 1., 0.6, 0.6, 0.6]]])

trial = np.array([[[1,  1,  2,  2,  3,  3,  4,  4,  5,  6,  5,  7,  6,  7,  8,  9,
                    10, 11,  8, 12,  9, 10, 11, 12],
                   [1,  2,  3,  4,  5,  6,  1,  2,  3,  7,  4,  5,  8,  9,  6, 10,
                    7,  8,  9, 11, 10, 11, 12, 12]],

                  [[1,  2,  1,  2,  3,  3,  4,  4,  5,  5,  6,  7,  8,  6,  7,  8,
                    9, 10, 11,  9, 10, 11, 12, 12],
                   [1,  1,  2,  3,  2,  3,  4,  5,  6,  4,  7,  5,  6,  8,  9,  7,
                      8, 10, 11,  9, 12, 10, 11, 12]]])

stim = np.array([[[0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1,
                   1, 1, 1],
                  [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
                   1, 1, 0]],

                 [[0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,
                   1, 0, 1],
                  [1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
                     1, 1, 1]]])
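
These arrays are consistent with the description above; a quick way to check the shapes, the per-stimulus trial counters, and the one-to-one stimulus/split pairing:

assert action.shape == reward.shape == split.shape == trial.shape == stim.shape == (2, 2, 24)
for subj in range(2):
    for block in range(2):
        for s in range(2):
            mask = stim[subj, block] == s
            assert mask.sum() == 12                                         # 12 presentations per stimulus
            assert sorted(trial[subj, block][mask]) == list(range(1, 13))   # counter runs 1..12
            assert len(np.unique(split[subj, block][mask])) == 1            # one split value per stimulus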

My RL Model:

def update_Q(action, stim, reward, Qs, Q0, alpha, decay):
    # adapted from Ricardo's tutorial
    # action, stim, reward: one trial's data, shape (nBlock, nSubj); Qs: (nSubj, nBlock, nStim, nAction)
    idx_noNaN_b, idx_noNaN_s = at.nonzero(~at.isnan(reward))  # skip missing trials
    stim, action, reward = (stim[idx_noNaN_b, idx_noNaN_s],
                            action[idx_noNaN_b, idx_noNaN_s],
                            reward[idx_noNaN_b, idx_noNaN_s])
    # delta-rule update of the chosen action's value for the presented stimulus
    Qs = at.set_subtensor(Qs[idx_noNaN_s, idx_noNaN_b, stim, action],
                          Qs[idx_noNaN_s, idx_noNaN_b, stim, action]
                          + alpha[idx_noNaN_s] * (reward - Qs[idx_noNaN_s, idx_noNaN_b, stim, action]))
    return Qs + decay[:, None, None, None] * (Q0 - Qs)  # the forgetting mechanism: decay back toward Q0

def aesara_llik_td(data, alpha, beta, decay, gamma):
    action, reward, split, trial, stim = data
    num_subj, num_block, _ = action.shape
    num_stim = 2
    # after the transpose below, each data tensor has shape ((nTrial=12)*(nStim=2)=24, nBlock=2, nSubj=2)
    actions_ = at.as_tensor_variable(action, dtype='int32').T
    rewards_ = at.as_tensor_variable(reward, dtype='float64').T
    splits_ = at.as_tensor_variable(split, dtype='float64').T
    trials_ = at.as_tensor_variable(trial, dtype='int32').T
    stims_ = at.as_tensor_variable(stim, dtype='int32').T
    
    '''
    Modify the reward value based on splits_.
    For example, if reward = 1, then after modification reward = 1*(splits_ + gamma*(1 - 2*splits_)), where gamma is a model parameter;
    if reward = 2, then after modification reward = 2*(splits_ + gamma*(1 - 2*splits_));
    if reward = 0, then after modification reward = 0*(splits_ + gamma*(1 - 2*splits_)) = 0.
    '''
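    # e.g. with split = 0.3 and gamma = 0.5, a reward of 2 becomes 2*(0.3 + 0.5*(1 - 2*0.3)) = 1.0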
    FS_splits_ = splits_ + gamma[None,None, :]*(1 - 2*splits_) 
    idx_noNaN_t, idx_noNaN_b, idx_noNaN_s = at.nonzero(~at.isnan(rewards_))
    rewards_ = at.set_subtensor(rewards_[idx_noNaN_t, idx_noNaN_b, idx_noNaN_s], rewards_[
        idx_noNaN_t, idx_noNaN_b, idx_noNaN_s] * FS_splits_[idx_noNaN_t, idx_noNaN_b, idx_noNaN_s])

    '''
    Without the modification, we would initialize the Q tensor to at.ones, because 1 is the expected reward for random guessing.
    So here, after the modification, we initialize the Q tensor to 1*(splits_ + gamma*(1 - 2*splits_)), which is the variable FS_splits_.
    The Q tensor has shape (n_subj=2, n_block=2, n_stimulus=2, n_actions=3).
    '''
    Q0 = at.empty((num_subj, num_block, num_stim), dtype='float64')  # ignore the action dimension for now
    idx_t1 = at.nonzero(at.neq(trials_, 1))  # trials whose counter != 1; any such occurrence works, since split is fixed per stimulus
    stim_t1, FS_split_t1 = stims_[idx_t1[0], idx_t1[1], idx_t1[2]], FS_splits_[
        idx_t1[0], idx_t1[1], idx_t1[2]] # each stimulus is uniquely paired with a split number
    Q0 = at.set_subtensor(Q0[idx_t1[2], idx_t1[1], stim_t1], FS_split_t1) 
    Q0 = at.repeat(at.expand_dims(Q0, axis=3), 3, axis=3) #repeats the same initial Q value for all 3 candidate actions
    # Compute the Q values for each trial
    Qs = Q0.copy()
    Qs, _ = aesara.scan(
        fn=update_Q,
        sequences=[actions_, stims_, rewards_],
        outputs_info=[Qs],
        non_sequences=[Q0, alpha, decay])

    # Apply the softmax transformation; this part is the same as in a typical RL model
    Qs = Qs[:-1] * beta[None, :, None, None, None]
    logp_actions = Qs - at.logsumexp(Qs, axis=4, keepdims=True)
    # Calculate the log likelihood of the observed actions
    rewards_, stims_, actions_ = rewards_[1:], stims_[1:], actions_[1:]  # drop the first trial, which is a random choice
    idx_noNaN_t, idx_noNaN_b, idx_noNaN_s = at.nonzero(~at.isnan(rewards_))
    stim_idx_noNaN, action_idx_noNaN = stims_[
        idx_noNaN_t, idx_noNaN_b, idx_noNaN_s], actions_[idx_noNaN_t, idx_noNaN_b, idx_noNaN_s]
    logp_actions = logp_actions[idx_noNaN_t, idx_noNaN_s,
                                idx_noNaN_b, stim_idx_noNaN, action_idx_noNaN]
    return at.sum(logp_actions)  # pm.Potential expects the (positive) log-likelihood
    
n_subj = 2
with pm.Model() as m:
    alpha = pm.Beta('alpha', alpha=1, beta=1, shape=n_subj)
    beta = pm.Gamma('beta', alpha=3, beta=1/2, shape=n_subj)
    decay = pm.Beta('decay', alpha=2, beta=15, shape=n_subj)
    gamma = pm.Normal('gamma', mu=0, sigma=1, shape=n_subj)
    like = pm.Potential('like', aesara_llik_td(data=[action,reward,split,trial,stim], alpha=alpha, beta=beta, decay=decay, gamma=gamma))
    tr = pm.sample(draws=3000, tune=1000, chains=4, cores=4)

I think you’d only need these 3 packages and Python 3.9:

import numpy as np
import pymc as pm #pymc4
import aesara.tensor as at
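
In case the Aesara indexing obscures the algorithm, here is roughly the per-trial computation the scan is meant to implement for a single subject and block, written in plain NumPy (a readability sketch only: it ignores the NaN handling and is not the code I actually fit):

def loglike_one_block(actions, stims, rewards_mod, q0, alpha, beta, decay, n_act=3):
    # rewards_mod are already multiplied by (split + gamma*(1 - 2*split));
    # q0: length-n_stim array with the matching per-stimulus initial value
    Q = np.repeat(np.asarray(q0, dtype=float)[:, None], n_act, axis=1)  # shape (n_stim, n_act)
    Q0 = Q.copy()
    logp = 0.0
    for t, (a, s, r) in enumerate(zip(actions, stims, rewards_mod)):
        v = beta * Q[s]
        if t > 0:  # the first trial of a block is treated as a random choice and skipped
            logp += v[a] - np.log(np.exp(v).sum())
        Q[s, a] += alpha * (r - Q[s, a])  # delta-rule update for the chosen action
        Q += decay * (Q0 - Q)             # forgetting: decay all values back toward Q0
    return logp

The Aesara version above is meant to do the same thing, vectorized over subjects and blocks and with NaN trials skipped.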

Here are some things I tried:

  1. If I fix gamma at 0, no divergences occur.
  2. If I fix it at, say, 1.2, I start seeing some divergences (around 6, instead of over 1000), meaning the value of gamma matters somehow.
  3. I adjusted the sd of the normal prior on gamma to be larger or smaller; it didn’t help.
  4. I adjusted the alpha and beta of the Gamma prior on the beta parameter (the inverse softmax temperature); it didn’t help.
  5. I tried gamma = pm.TruncatedNormal('gamma', mu=0, sigma=10, lower=-2, upper=3, shape=n_subj); it didn’t help.
  6. I tried fitting only the first subject’s data: no divergences. Then only the second subject’s data: divergences again.
  7. Oddly, if I change splits_ + gamma[None,None, :]*(1 - 2*splits_) into splits_ + gamma[None,None, :]*(1 - splits_), no divergences occur. The only difference is the “2*”; why does it make such a huge difference? (A quick numeric look at this multiplier follows right after this list.) The caveat is that I still get r-hat and effective sample size warnings, so maybe it doesn’t entirely fix it.
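
Here is the quick check on that multiplier from point 7, for the split levels that actually occur in the example data (pure exploration; I am not sure this explains the divergences):

splits = np.array([0.1, 0.3, 0.6, 1.0])    # split levels present in the example data
for g in [0.0, 0.5, 1.0, 1.5, 2.0]:        # gamma values the Normal(0, 1) prior can easily reach
    with_2 = splits + g * (1 - 2 * splits)      # multiplier as written in the model
    without_2 = splits + g * (1 - splits)       # the variant from point 7
    print(f"gamma={g:.1f}  with 2*: {np.round(with_2, 2)}  without 2*: {np.round(without_2, 2)}")

For split = 1.0 the “2*” version gives 1 - gamma, so the multiplier (and with it every nonzero reward for that stimulus) changes sign once gamma exceeds 1, and the 0.6 level shrinks toward 0 as gamma grows; the variant without the “2*” stays positive for non-negative gamma. Whether that is the actual cause, I don’t know.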

I think I fixed it, but I don’t really understand why. Instead of modifying the rewards before passing them into aesara.scan, I used split to modify the Q values afterward, on the same line where the beta temperature is applied. Because the effect of split on the rewards is the same across trials, the two algorithms are equivalent. So I guess lesson one for debugging divergences is: try writing the same algorithm in different ways. However, I wasn’t on the most up-to-date version of pymc (4.1.3) but on 4.0.0, so maybe this divergence has something to do with a bug around aesara.scan in 4.0.0. Another tip for debugging divergences is centering, which helped with the divergence issue I encountered when fitting another RL model variant.
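
For what it’s worth, the change was roughly along these lines (a sketch, not my verbatim code: rewards_raw_ stands for the unmodified rewards, Q0_ones for an all-ones initial Q tensor, and FS0 for a per-(subject, block, stimulus) tensor holding splits_ + gamma*(1 - 2*splits_), i.e. what Q0 used to be built from before repeating over actions):

# run the scan on the raw rewards with an all-ones initial Q...
Qs, _ = aesara.scan(
    fn=update_Q,
    sequences=[actions_, stims_, rewards_raw_],  # raw, unmodified rewards
    outputs_info=[Q0_ones.copy()],
    non_sequences=[Q0_ones, alpha, decay])
# ...and apply the split/gamma factor to the Q values on the same line as the beta temperature;
# FS0 has shape (n_subj, n_block, n_stim)
Qs = Qs[:-1] * FS0[None, :, :, :, None] * beta[None, :, None, None, None]

Because the delta-rule update and the forgetting step are both linear in the reward and in Q0, scaling the rewards and Q0 by a per-stimulus constant is the same as scaling the resulting Q values by that constant, which is why the two versions should give the same likelihood.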