Fitting a simple Reinforcement Learning model to behavioral data with PyMC3 (Jupyter NB)


I recently wrote a Jupyter Notebook illustrating how one can fit a very simple RL model to simulated behavioral data. I thought this could be useful for others, especially when writing models that require recursive calculations based on sampled parameters:


Have a nice day :slight_smile:


Thank you for sharing, this looks great :tada:


FYI: @Maria also shared her RL models on human data here: Modeling reinforcement learning of human participant using PyMC3


I’m loving this sharing mood :wink:


Is there a version for PyMC 4 with Aesara :pray:? I tried to use your code and got an "Input dimension mismatch" error:

import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm

def update_Q(action, reward, Qs, alpha):
    """
    This function updates the Q table according to the RL update rule.
    It will be called by aesara.scan to do so recursively, given the observed data and the alpha parameter.
    This could have been replaced by the following lambda expression in the aesara.scan fn argument:
        fn=lambda action, reward, Qs, alpha: at.set_subtensor(Qs[action], Qs[action] + alpha * (reward - Qs[action]))
    """
    Qs = at.set_subtensor(Qs[action], Qs[action] + alpha * (reward - Qs[action]))
    return Qs

def theano_llik_td(alpha, beta, actions, rewards):
    rewards = aesara.shared(np.asarray(rewards, dtype='int16'))
    actions = aesara.shared(np.asarray(actions, dtype='int16'))

    # Compute the Qs values
    Qs = 0.5 * at.ones((2,), dtype='float64')
    Qs, updates = aesara.scan(
        fn=update_Q,
        sequences=[actions, rewards],
        outputs_info=[Qs],
        non_sequences=[alpha])

    # Apply the softmax transformation
    Qs_ = Qs[:-1] * beta
    log_prob_actions = Qs_ - pm.math.logsumexp(Qs_, axis=1)

    # Calculate the negative log likelihood of the observed actions
    log_prob_actions = log_prob_actions[at.arange(actions.shape[0]-1), actions[1:]]
    return at.sum(log_prob_actions)  # PyMC makes it negative by default

with pm.Model() as m:
    alpha = pm.Beta('alpha', 1, 1)
    beta = pm.HalfNormal('beta', 10)
    like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
    tr = pm.sample()
ValueError                                Traceback (most recent call last)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    960 try:
    961     outputs = (
--> 962         self.fn()
    963         if output_subset is None
    964         else self.fn(output_subset=output_subset)
    965     )
    966 except Exception:

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Input In [12], in <cell line: 7>()
      9 beta = pm.HalfNormal('beta', 10)
     10 like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
---> 11 tr = pm.sample()

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    531         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    532"Auto-assigning NUTS sampler...")
--> 533     initial_points, step = init_nuts(
    534         init=init,
    535         chains=chains,
    536         n_init=n_init,
    537         model=model,
    538         random_seed=random_seed_list,
    539         progressbar=progressbar,
    540         jitter_max_retries=jitter_max_retries,
    541         tune=tune,
    542         initvals=initvals,
    543         **kwargs,
    544     )
    546 if initial_points is None:
    547     # Time to draw/evaluate numeric start points for each chain.
    548     ipfns = make_initial_point_fns_per_chain(
    549         model=model,
    550         overrides=initvals,
    551         jitter_rvs=filter_rvs_to_jitter(step),
    552         chains=chains,
    553     )

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2480"Initializing NUTS using {init}...")
   2482 cb = [
   2483     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
   2484     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
   2485 ]
-> 2487 initial_points = _init_jitter(
   2488     model,
   2489     initvals,
   2490     seeds=random_seed_list,
   2491     jitter="jitter" in init,
   2492     jitter_max_retries=jitter_max_retries,
   2493 )
   2495 apoints = [ for point in initial_points]
   2496 apoints_data = [ for apoint in apoints]

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2379 if i < jitter_max_retries:
   2380     try:
-> 2381         model.check_start_vals(point)
   2382     except SamplingError:
   2383         # Retry with a new seed
   2384         seed = rng.randint(2**30, dtype=np.int64)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in Model.check_start_vals(self, start)
   1716     valid_keys = ", ".join(self.named_vars.keys())
   1717     raise KeyError(
   1718         "Some start parameters do not appear in the model!\n"
   1719         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1720     )
-> 1722 initial_eval = self.point_logps(point=elem)
   1724 if not all(np.isfinite(v) for v in initial_eval.values()):
   1725     raise SamplingError(
   1726         "Initial evaluation of model at starting point failed!\n"
   1727         f"Starting values:\n{elem}\n\n"
   1728         f"Initial evaluation results:\n{initial_eval}"
   1729     )

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in Model.point_logps(self, point, round_vals)
   1757 factors = self.basic_RVs + self.potentials
   1758 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1759 return {
   1760 np.round(np.asarray(factor_logp), round_vals)
   1761     for factor, factor_logp in zip(
   1762         factors,
-> 1763         self.compile_fn(factor_logps_fn)(point),
   1764     )
   1765 }

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in PointFunc.__call__(self, state)
   1861 def __call__(self, state):
-> 1862     return self.f(**state)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    973     if hasattr(self.fn, "thunks"):
    974         thunk = self.fn.thunks[self.fn.position_of_error]
--> 975     raise_with_op(
    976         self.maker.fgraph,
    977         node=self.fn.nodes[self.fn.position_of_error],
    978         thunk=thunk,
    979         storage_map=getattr(self.fn, "storage_map", None),
    980     )
    981 else:
    982     # old-style linkers raise their own exceptions
    983     raise

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     warnings.warn(
    534         f"{exc_type} error does not allow us to add an extra error message"
    535     )
    536     # Some exception need extra parameter in inputs. So forget the
    537     # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    959 t0_fn = time.time()
    960 try:
    961     outputs = (
--> 962         self.fn()
    963         if output_subset is None
    964         else self.fn(output_subset=output_subset)
    965     )
    966 except Exception:
    967     restore_defaults()

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.
Apply node that caused the error: Elemwise{Sub}[(0, 0)](Elemwise{Mul}[(0, 0)].0, Elemwise{Composite{(i0 + log(i1))}}[(0, 0)].0)
Toposort index: 68
Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (1, None))]
Inputs shapes: [(99, 2), (1, 99)]
Inputs strides: [(16, 8), (792, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[AdvancedSubtensor(Elemwise{Sub}[(0, 0)].0, ARange{dtype='int64'}.0, Subtensor{int64::}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

I updated the notebook to use Aesara and the new version of PyMC: stats/RL_PyMC.ipynb at master · ricardoV94/stats · GitHub

The shape problem is indeed new, and it is solved by passing keepdims=True to the logsumexp function.
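The shapes in the error make sense once the broadcasting is spelled out. Below is a NumPy-only sketch of the softmax step from theano_llik_td (NumPy stands in for Aesara here because the broadcasting rules are the same; the random Q values and beta are illustrative, not taken from the thread). With 100 trials, as the 99 in the traceback suggests, Qs[:-1] has shape (99, 2). A row-wise log-sum-exp without keepdims has shape (99,), which broadcasting aligns with the action axis of size 2 instead of the trial axis, hence the (99, 2) vs (1, 99) mismatch. Keeping the reduced axis gives (99, 1), which broadcasts across both actions.

```python
import numpy as np

def logsumexp(x, axis, keepdims):
    # Numerically stable log-sum-exp: subtract the max before exponentiating
    x_max = np.max(x, axis=axis, keepdims=True)
    out = x_max + np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

rng = np.random.default_rng(0)
Qs = rng.uniform(size=(100, 2))  # illustrative Q values: one row per trial, one column per action
beta = 3.0                       # illustrative inverse temperature

Qs_ = Qs[:-1] * beta             # shape (99, 2)

# Without keepdims the normalizer has shape (99,), which NumPy lines up with the
# action axis (size 2), so the subtraction fails just like the Aesara graph did
try:
    Qs_ - logsumexp(Qs_, axis=1, keepdims=False)
    failed = False
except ValueError:
    failed = True

# With keepdims the normalizer has shape (99, 1) and broadcasts over both actions
log_prob_actions = Qs_ - logsumexp(Qs_, axis=1, keepdims=True)
print(failed, log_prob_actions.shape)  # True (99, 2)
```

In the Aesara function itself, the fix described above amounts to log_prob_actions = Qs_ - pm.math.logsumexp(Qs_, axis=1, keepdims=True).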


Thank You!!!

Thanks so much for the amazing notebook! I have a quick question. If I have several independent blocks of learning data that I want to fit, is there a way to vectorize this in Aesara as well? Should I use a nested aesara.scan? Essentially, I want the code below but without the for loop, to make it faster:
for block_idx in blocks:
    Qs = 0.5 * at.ones((2,), dtype='float64')
    Qs, updates = aesara.scan(
        fn=update_Q,
        sequences=[actions[block_idx], rewards[block_idx]],
        outputs_info=[Qs],
        non_sequences=[alpha])

If they are all the same length, you should be able to run them in a single scan (you might need to adjust the indexing and axis operations). Otherwise, if they are similar in length, you could pad the short ones with zeros or NaN and index afterwards.

But there is nothing wrong with using a double scan, and that might be the best solution for your case. It's also probably the easiest way to get started. If you run into speed issues, you can try to get rid of it later.
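The single-scan-with-padding idea can be prototyped in NumPy before writing the Aesara graph. The sketch below is a hypothetical NumPy version, not code from the notebook: the Q table gets shape (n_blocks, 2), each trial step updates one Q value per block at once through fancy indexing, shorter blocks are padded with 0 at the end, and a valid mask drops the padded trials when summing. The loop over trials is the part a single aesara.scan would replace; inside the scan step the update would presumably become at.set_subtensor(Qs[at.arange(n_blocks), action], ...), with actions and rewards transposed to (n_trials, n_blocks) so that scan iterates over trials. The per_block_llik loop is included only as a reference to check the vectorized result; alpha, beta and the data are made up.

```python
import numpy as np

def vectorized_llik(alpha, beta, actions, rewards, valid):
    """Total log-likelihood of several independent blocks, updated in lockstep.

    actions, rewards and valid have shape (n_blocks, n_trials). Shorter blocks
    are padded at the end (here with 0) and `valid` marks the real trials.
    """
    n_blocks, n_trials = actions.shape
    blocks = np.arange(n_blocks)
    Qs = np.full((n_blocks, 2), 0.5)
    history = []
    # This loop over trials is what a single aesara.scan would replace; each step
    # updates one Q value per block at once via fancy indexing
    for t in range(n_trials):
        a, r = actions[:, t], rewards[:, t]
        Qs = Qs.copy()
        Qs[blocks, a] += alpha * (r - Qs[blocks, a])
        history.append(Qs)
    Qs_hist = np.stack(history)                      # (n_trials, n_blocks, 2)

    # Softmax over the action axis; keepdims so the normalizer broadcasts
    logits = beta * Qs_hist[:-1]
    x_max = logits.max(axis=-1, keepdims=True)
    log_norm = x_max + np.log(np.exp(logits - x_max).sum(axis=-1, keepdims=True))
    log_probs = logits - log_norm                    # (n_trials - 1, n_blocks, 2)

    # Q values after trial t score the action chosen on trial t + 1
    next_actions = actions[:, 1:].T                  # (n_trials - 1, n_blocks)
    trials = np.arange(n_trials - 1)[:, None]
    chosen = log_probs[trials, blocks[None, :], next_actions]
    # "Index afterwards": keep only entries that correspond to real trials
    return chosen[valid[:, 1:].T].sum()


def per_block_llik(alpha, beta, actions, rewards):
    """Reference: the same likelihood for one block, with a plain Python loop."""
    Qs = np.full(2, 0.5)
    total = 0.0
    for t in range(len(actions)):
        Qs[actions[t]] += alpha * (rewards[t] - Qs[actions[t]])
        if t + 1 < len(actions):
            logits = beta * Qs
            total += logits[actions[t + 1]] - np.log(np.exp(logits).sum())
    return total


# Usage: two made-up blocks of 5 and 3 trials, the second padded to 5
rng = np.random.default_rng(1)
lengths = [5, 3]
n_trials = max(lengths)
actions = np.zeros((2, n_trials), dtype=int)
rewards = np.zeros((2, n_trials))
valid = np.zeros((2, n_trials), dtype=bool)
for b, n in enumerate(lengths):
    actions[b, :n] = rng.integers(0, 2, size=n)
    rewards[b, :n] = rng.integers(0, 2, size=n)
    valid[b, :n] = True

total = vectorized_llik(0.3, 2.0, actions, rewards, valid)
reference = sum(per_block_llik(0.3, 2.0, actions[b, :n], rewards[b, :n])
                for b, n in enumerate(lengths))
print(np.isclose(total, reference))
```

Padding only at the end keeps things simple: the Q updates made on padded trials never feed into a scored trial, so masking the scoring step is enough.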

Thanks for the advice! Another question: is there a way to deal with NaN values in my RL likelihood function? What I used to do is first retrieve the indices of the data (as a NumPy array) where the value is not NaN, through idx = np.argwhere(~pd.isnull(rewards)).flatten(), and then update the Q matrix using these indices. However, I can't find the Aesara equivalent of pd.isnull.

There might be something like an at.isnan? Usually if there’s a numpy function there should be an Aesara equivalent…
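Since actions and rewards are observed data that stay plain NumPy arrays until theano_llik_td wraps them with aesara.shared, one option is to do the NaN filtering in NumPy beforehand, so no Aesara counterpart of pd.isnull is needed. The sketch below assumes that a trial with any missing value can simply be skipped: the Q table is then left unchanged by that trial and its action is not scored, which may or may not suit a given experiment. The data are made up; the cleaned arrays would then be passed to theano_llik_td as before.

```python
import numpy as np

# Made-up observed data; NaN marks trials with a missing action or reward
actions = np.array([0, 1, np.nan, 1, 0])
rewards = np.array([1, 0, np.nan, 1, np.nan])

# Mask built in NumPy (pd.isnull would work the same way) before the data
# ever reach aesara.shared, so no Aesara counterpart is needed
valid = ~(np.isnan(actions) | np.isnan(rewards))
idx = np.flatnonzero(valid)   # equivalent to np.argwhere(valid).flatten() here

# Only now cast to integers: NaN cannot be represented in an int16 array
clean_actions = actions[idx].astype('int16')
clean_rewards = rewards[idx].astype('int16')
print(idx, clean_actions, clean_rewards)  # [0 1 3] [0 1 1] [1 0 1]
```

Casting after filtering matters here, because theano_llik_td converts the data with np.asarray(..., dtype='int16'), and an int16 array has no way to hold NaN.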


I'm trying to find the indices of an element in my array, which in NumPy I would do with np.argwhere, but Aesara doesn't seem to have argwhere. The workaround I came up with is to apply at.nonzero to a boolean array that is True wherever the array equals that element. Posting this in case anyone else runs into the same issue.
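The workaround rests on the fact that, for 1-D input, np.argwhere(mask).flatten() and np.nonzero(mask)[0] return the same indices; the NumPy sketch below checks this with made-up values. In Aesara the analogous expression would presumably be at.nonzero(at.eq(x, value))[0]. at.eq is used there rather than == because, to my knowledge, == on a symbolic Aesara tensor is not an elementwise comparison (as in Theano).

```python
import numpy as np

x = np.array([3, 1, 4, 1, 5, 9, 2, 6])
value = 1

# np.argwhere returns an (n_matches, ndim) array, so 1-D input needs flattening
via_argwhere = np.argwhere(x == value).flatten()
# np.nonzero returns one index array per dimension; for 1-D input take [0]
via_nonzero = np.nonzero(x == value)[0]
print(via_argwhere, via_nonzero)  # [1 3] [1 3]
```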
