Fitting a Reinforcement Learning model

I am trying to adapt the "Fitting a Reinforcement Learning Model to Behavioral Data with PyMC" example from the PyMC example gallery to a more complex model.
I am quite new to PyMC and am having trouble writing the step function to pass to scan.

The task I am modelling is a two-armed bandit task in which each arm delivers rewards and punishments independently.
The model is a Rescorla-Wagner-type model with a single learning rate and a softmax decision rule parametrised by an inverse temperature.
On each trial the model chooses one of the two arms and receives one of four outcomes (reward, punishment, reward and punishment, or nothing).
The model tracks four beliefs (arm A reward, arm A punishment, arm B reward, arm B punishment), but only the two beliefs of the chosen arm are updated on each trial.
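Written out, the model is as follows. The symbols are my notation: α is learning_rate, β is inv_temp, Q are the four beliefs and o the four 0/1 outcomes, indexed by arm k and by j for reward or punishment.

```latex
\begin{aligned}
V^{k}_t &= Q^{k,\mathrm{rew}}_t - Q^{k,\mathrm{pun}}_t, \qquad k \in \{A, B\} \\
P(c_t = A \mid Q_t) &= \frac{\exp(\beta V^{A}_t)}{\exp(\beta V^{A}_t) + \exp(\beta V^{B}_t)} \\
Q^{k,j}_{t+1} &= \begin{cases}
  Q^{k,j}_t + \alpha\,(o^{k,j}_t - Q^{k,j}_t) & \text{if } k = c_t \\
  Q^{k,j}_t & \text{otherwise}
\end{cases}, \qquad j \in \{\mathrm{rew}, \mathrm{pun}\} \\
\log \mathcal{L}(\alpha, \beta) &= \sum_t \log P(c_t \mid Q_t)
\end{aligned}
```

The last line is what a fit maximises: the sum of the log-probabilities of the observed choices, each evaluated with the beliefs from before that trial's update.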

import numpy as np
import jax.numpy as jnp
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pytensor.tensor as pt
from pytensor import scan, function
import scipy

def simulate_model(params, beliefs, outcomes):
    """
    params = [learning_rate, inv_temp]
     - learning rate
     - inverse temperature

    beliefs  = [belief_A_rew, belief_A_pun, belief_B_rew, belief_B_pun], shape (n_trials + 1, 4)
    outcomes = [reward_A, punish_A, reward_B, punish_B], values 0/1, shape (n_trials, 4)
    """
    learning_rate = params[0]
    inv_temp = params[1]
    n_trials = outcomes.shape[0]
    choices = np.zeros(n_trials)

    for t in range(n_trials):
        probA = np.exp(inv_temp * (beliefs[t, 0] - beliefs[t,1])) / (np.exp(inv_temp * (beliefs[t,0] - beliefs[t,1])) + np.exp(inv_temp* (beliefs[t,2] - beliefs[t,3]))) 

        choices[t] = int((probA > np.random.rand()) == 0)

        if choices[t] == 0:
            # choice is A: update arm A's beliefs, keep arm B's
            beliefs[t+1,0] = beliefs[t,0] + learning_rate * (outcomes[t,0] - beliefs[t,0])
            beliefs[t+1,1] = beliefs[t,1] + learning_rate * (outcomes[t,1] - beliefs[t,1])
            beliefs[t+1,2] = beliefs[t,2] # don't update when not chosen
            beliefs[t+1,3] = beliefs[t,3] # don't update when not chosen
        else:
            # choice is B: keep arm A's beliefs, update arm B's
            beliefs[t+1,0] = beliefs[t,0] # don't update when not chosen
            beliefs[t+1,1] = beliefs[t,1] # don't update when not chosen
            beliefs[t+1,2] = beliefs[t,2] + learning_rate * (outcomes[t,2] - beliefs[t,2])
            beliefs[t+1,3] = beliefs[t,3] + learning_rate * (outcomes[t,3] - beliefs[t,3])

    return choices, beliefs
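For reference, here is a self-contained variant of simulate_model that allocates beliefs itself, starting from zero, with one extra row for the state after the last trial. It comes with a usage example. The outcome probabilities, learning rate, inverse temperature and seed are made-up illustration values, not part of the task.

```python
import numpy as np

def simulate_model(params, outcomes, rng):
    """Simulate choices (0 = A, 1 = B) from the single-learning-rate RW model.

    params   : [learning_rate, inv_temp]
    outcomes : (n_trials, 4) array of 0/1, columns [reward_A, punish_A, reward_B, punish_B]
    Returns choices (n_trials,) and beliefs (n_trials + 1, 4).
    """
    learning_rate, inv_temp = params
    n_trials = outcomes.shape[0]
    choices = np.zeros(n_trials, dtype=int)
    beliefs = np.zeros((n_trials + 1, 4))  # row t+1 holds the beliefs after trial t
    for t in range(n_trials):
        net = np.array([beliefs[t, 0] - beliefs[t, 1], beliefs[t, 2] - beliefs[t, 3]])
        p = np.exp(inv_temp * net)
        p /= p.sum()  # softmax over the two arms
        choices[t] = 0 if rng.random() < p[0] else 1
        beliefs[t + 1] = beliefs[t]  # copy; the unchosen arm keeps its beliefs
        cols = slice(0, 2) if choices[t] == 0 else slice(2, 4)  # chosen arm's columns
        beliefs[t + 1, cols] += learning_rate * (outcomes[t, cols] - beliefs[t, cols])
    return choices, beliefs

rng = np.random.default_rng(0)
p_outcome = np.array([0.7, 0.3, 0.4, 0.2])  # made-up P(outcome = 1) for each column
outcomes = (rng.random((200, 4)) < p_outcome).astype(int)
choices, beliefs = simulate_model([0.2, 2.0], outcomes, rng)
print(choices[:10], beliefs[-1])
```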

Following the tutorial, I am now trying to write the PyTensor version of the likelihood so that I can estimate the parameters with PyMC and compare them with the maximum likelihood results.
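For the maximum likelihood side of the comparison, here is a minimal sketch that runs scipy.optimize.minimize on a NumPy negative log-likelihood. It simulates its own data from made-up parameters (learning rate 0.3, inverse temperature 3) and made-up outcome probabilities. The bounds and the starting point are also assumptions. With real data you would pass the observed choices and outcomes instead.

```python
import numpy as np
from scipy.optimize import minimize

def neg_log_likelihood(params, choices, outcomes):
    """Negative log-likelihood of the choices under the RW + softmax model."""
    learning_rate, inv_temp = params
    belief = np.zeros(4)  # [A_rew, A_pun, B_rew, B_pun], starting at 0
    nll = 0.0
    for c, o in zip(choices, outcomes):
        logits = inv_temp * np.array([belief[0] - belief[1], belief[2] - belief[3]])
        logits -= logits.max()  # shift for numerical stability; softmax is unchanged
        nll -= logits[c] - np.log(np.exp(logits).sum())
        cols = slice(0, 2) if c == 0 else slice(2, 4)  # chosen arm's two beliefs
        belief[cols] += learning_rate * (o[cols] - belief[cols])
    return nll

# Made-up data simulated from learning_rate = 0.3, inv_temp = 3.0.
rng = np.random.default_rng(1)
n_trials, true_lr, true_beta = 300, 0.3, 3.0
outcomes = (rng.random((n_trials, 4)) < [0.7, 0.3, 0.4, 0.2]).astype(int)
choices = np.zeros(n_trials, dtype=int)
belief = np.zeros(4)
for t in range(n_trials):
    v = true_beta * np.array([belief[0] - belief[1], belief[2] - belief[3]])
    p_A = 1.0 / (1.0 + np.exp(v[1] - v[0]))  # softmax probability of arm A
    choices[t] = 0 if rng.random() < p_A else 1
    cols = slice(0, 2) if choices[t] == 0 else slice(2, 4)
    belief[cols] += true_lr * (outcomes[t, cols] - belief[cols])

res = minimize(
    neg_log_likelihood,
    x0=[0.5, 1.0],
    args=(choices, outcomes),
    method="L-BFGS-B",
    bounds=[(1e-3, 1.0), (1e-3, 20.0)],
)
print(res.x, res.fun)
```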

def update_model(choice, outcome, choice_probability, belief, learning_rate, inv_temp):
    if choice == 0:
        c = 0
        lr = pt.as_tensor([learning_rate, learning_rate, 0, 0])
    else:
        c = 2
        lr = pt.as_tensor([0, 0, learning_rate, learning_rate])

    choice_probability = np.exp(inv_temp * (belief[c] - belief[c+1])) / (np.exp(inv_temp * (belief[0] - belief[1])) + np.exp(inv_temp* (belief[2] - belief[3]))) 

    new_belief = belief + lr * (outcome - belief)

    return choice_probability, new_belief

# choices has shape (n_trials,)
# outcomes has shape (n_trials, 4): scan slices sequences along the first axis, one row per trial
choices_pt = pt.as_tensor_variable(choices, dtype='int32')
outcomes_pt = pt.as_tensor_variable(outcomes, dtype='int32')

learning_rate = pt.scalar('learning_rate')
inv_temp = pt.scalar('inv_temp')

choice_probabilities = pt.zeros(1, dtype='float64')
beliefs = pt.zeros((1,4), dtype='float64')

results, updates = scan(
    fn=update_model,
    sequences=[choices_pt, outcomes_pt],
    non_sequences=[learning_rate, inv_temp],
    outputs_info=[choice_probabilities, beliefs],
)

neg_loglike = -pt.sum(pt.log(results[0]))  # negative log-likelihood: sum of log choice probabilities

pytensor_llik_td = function(
    inputs=[learning_rate, inv_temp], outputs=neg_loglike, on_unused_input="ignore"
)
result = pytensor_llik_td(0.2, 2)

The code runs fine until I try to compile the function with pytensor.function, at which point I get the following error:

TypeError: Inconsistency in the inner graph of scan ‘scan_fn’ : an input and an output are associated with the same recurrent state and should have compatible types but have type ‘Vector(float64, shape=(1,))’ and ‘Vector(float64, shape=(?,))’ respectively.

When I print the types of the tensors inside update_model, I get:

<Scalar(int32, shape=())>             # choice
<Vector(int32, shape=(?,))>           # outcome
<Vector(float64, shape=(1,))>         # choice_probability
<Matrix(float64, shape=(1, 4))>       # belief

What is the correct way to pass a single row of outcomes to scan at each iteration?
Any help is greatly appreciated! :slight_smile: