Memory problems with nested theano.scan

My problem: Running the model on a couple of very small datasets works fine, but running it on a larger number of bigger datasets leads to memory errors.

The model is a reinforcement learning model. It takes in sequences of stimuli, actions, and rewards, and calculates action probabilities based on some free parameters. It works fine when run on 2 participants with 10 trials (as below), but it gets very slow and eventually crashes when run on 20 participants with 100 trials, i.e., still a pretty small dataset.

A few thoughts that might be related to the issue:

Here’s my model:

import numpy as np
import pymc3 as pm
import theano
import theano.tensor as T

# Placeholder constants so the snippet runs on its own -- the real values are set elsewhere in my script
n_seasons, n_TS, n_aliens, n_actions = 3, 3, 4, 3
alien_initial_Q = 1.2
n_samples, n_tune, n_chains, n_cores = 100, 100, 2, 1

n_subj, n_trials = 2, 10
seasons = np.random.choice(range(n_seasons), size=[n_trials, n_subj])  # np.ones([n_trials, n_subj], dtype=int)
aliens = np.random.choice(range(n_aliens), size=[n_trials, n_subj])  # np.ones([n_trials, n_subj], dtype=int)
actions = np.random.choice(range(n_actions), size=[n_trials, n_subj])
rewards = 10 * np.random.rand(n_trials * n_subj).reshape([n_trials, n_subj]).round(2)

with pm.Model() as model:

    # Sample subject parameters
    alpha = pm.Uniform('alpha', lower=0, upper=1, shape=n_subj, testval=np.random.choice([0.1, 0.5], n_subj))
    beta = pm.HalfNormal('beta', sd=5, shape=n_subj, testval=5 * np.random.rand(n_subj).round(2))

    # Initialize Q-values
    Q_low0 = alien_initial_Q * T.ones([n_subj, n_TS, n_aliens, n_actions])
    Q_high0 = alien_initial_Q * T.ones([n_subj, n_seasons, n_TS])

    # Define function to update Q-values based on stimulus, action, and reward
    def update_Q_low(season, alien, action, reward, Q_low, alpha):
        # Loop over trials: take data for all subjects, 1 trial
        Q_low_new = Q_low.copy()
        RPE_low = alpha * (reward - Q_low_new[T.arange(n_subj), season, alien, action])
        Q_low_new = T.set_subtensor(Q_low_new[T.arange(n_subj), season, alien, action],
                                    Q_low_new[T.arange(n_subj), season, alien, action] + RPE_low)
        return Q_low_new

    # Get Q-values for all trials
    Q_low, _ = theano.scan(fn=update_Q_low,
                           sequences=[seasons, aliens, actions, rewards],
                           outputs_info=[Q_low0],
                           non_sequences=[alpha])

    Q_low = T.concatenate([[Q_low0], Q_low[:-1]], axis=0)  # Add first trial Q-values, remove last trial Q-values

    # Define function to transform Q-values into action probabilities
    def softmax(Q_low, season, alien, beta):
        # Loop over subjects within 1 trial
        Q_low_stim = Q_low[season, alien]
        Q_low_exp = T.exp(beta * Q_low_stim)
        p = Q_low_exp / T.sum(Q_low_exp)
        return p

    def softmax_trial_wrapper(Q_low, season, alien, beta):
        # Loop over trials
        p, _ = theano.scan(fn=softmax,
                           sequences=[Q_low, season, alien, beta])
        return p

    # Transform Q-values into action probabilities for all subj, all trials
    p, _ = theano.scan(fn=softmax_trial_wrapper,
                       sequences=[Q_low, seasons, aliens],
                       non_sequences=[beta])

    action_wise_p = p.flatten().reshape([n_trials * n_subj, n_actions])

    # Select actions
    actions = pm.Categorical('actions', p=action_wise_p, observed=actions.flatten())

    # Sample model
    trace = pm.sample(n_samples, tune=n_tune, chains=n_chains, cores=n_cores)

I am using:

  • PyMC3 Version: 3.4.1
  • Theano Version: 1.0.2
  • Python Version: 3.6.6
  • Operating system: Win 10 Pro
  • How did you install PyMC3: conda

Update: I am now pretty sure that the issue is caused by the nested theano.scan loops. When I remove them, the model runs fine for large datasets.

How can I get rid of the nested loops? I am using them to transform Q-values into probabilities, i.e., to turn each triple of action values into a triple of action probabilities. The Q-values lie roughly between 0 and 10, and I apply a softmax transform and then normalize so the probabilities sum to 1. I loop twice to get the normalization right: I need individual triples (the 3 Q-values of 1 trial from 1 subject), rather than lists of lists of triples (Q-values of all trials from all subjects), to compute the correct normalization term. Is there a more elegant way to do this with matrix operations? (I sketch my rough idea below the code.)

Again, here is the code in question:

# Define function to transform Q-values into action probabilities
def softmax(Q_low, season, alien, beta):
    # Loop over subjects within 1 trial
    Q_low_stim = Q_low[season, alien]
    Q_low_exp = T.exp(beta * Q_low_stim)
    p = Q_low_exp / T.sum(Q_low_exp)
    return p

def softmax_trial_wrapper(Q_low, season, alien, beta):
    # Loop over trials
    p, _ = theano.scan(fn=softmax,
                       sequences=[Q_low, season, alien, beta])
    return p

# Transform Q-values into action probabilities for all subj, all trials
p, _ = theano.scan(fn=softmax_trial_wrapper,
                   sequences=[Q_low, seasons, aliens],
                   non_sequences=[beta])
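
For reference, this is the kind of vectorized replacement I have in mind, but I am not sure the indexing is right. It is an untested sketch: the names trial_idx, subj_idx, Q_stim, and Q_exp are just mine, and I am assuming Theano's advanced indexing broadcasts the index arrays the same way numpy does.

# Untested sketch: vectorized softmax over the action axis, replacing both scans.
# After the concatenation step, Q_low has shape (n_trials, n_subj, n_TS, n_aliens, n_actions).
trial_idx = T.arange(n_trials)[:, None]   # shape (n_trials, 1)
subj_idx = T.arange(n_subj)[None, :]      # shape (1, n_subj)

# Q-values of the presented season/alien for every trial and subject: (n_trials, n_subj, n_actions)
Q_stim = Q_low[trial_idx, subj_idx, seasons, aliens]

# beta is one value per subject, so broadcast it over trials and actions
Q_exp = T.exp(beta[None, :, None] * Q_stim)
p = Q_exp / Q_exp.sum(axis=2, keepdims=True)   # each (trial, subject) triple sums to 1

action_wise_p = p.reshape([n_trials * n_subj, n_actions])

If the indexing and broadcasting really work like this in Theano, each subject's triple of Q-values would be normalized on its own (the sum runs only over the action axis), which is exactly what the inner scan was doing.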