Fitting a simple Reinforcement Learning model to behavioral data with PyMC (Jupyter NB)


I recently wrote a Jupyter Notebook illustrating how to fit a very simple RL model to simulated behavioral data. I thought it could be useful for others, especially when writing models that require recursive calculations based on sampled parameters:

Link: Fitting a Reinforcement Learning Model to Behavioral Data with PyMC — PyMC example gallery
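For readers who want the gist before opening the notebook, here is a plain NumPy sketch of the kind of model involved. A two-armed bandit learner updates its Q values recursively with learning rate `alpha` and chooses with a softmax controlled by inverse temperature `beta`. The reward probabilities, trial count and the names `simulate_rl` and `loglik_rl` are illustrative assumptions, not taken from the notebook. In PyMC the same recursion has to be expressed with `scan`, because `alpha` is a random variable rather than a fixed number.

```python
import numpy as np


def simulate_rl(alpha, beta, n_trials, reward_probs=(0.4, 0.6), seed=0):
    """Simulate a softmax Q-learning agent on a two-armed bandit (illustrative)."""
    rng = np.random.default_rng(seed)
    Qs = np.full(2, 0.5)  # initial Q values, matching 0.5 * ones(2) used in the thread
    actions = np.zeros(n_trials, dtype=int)
    rewards = np.zeros(n_trials, dtype=int)
    for t in range(n_trials):
        # Softmax choice probabilities (shifted by the max for numerical stability)
        logits = beta * Qs
        p = np.exp(logits - logits.max())
        p /= p.sum()
        a = rng.choice(2, p=p)
        r = int(rng.random() < reward_probs[a])
        # Delta-rule update of the chosen arm only
        Qs[a] += alpha * (r - Qs[a])
        actions[t], rewards[t] = a, r
    return actions, rewards


def loglik_rl(alpha, beta, actions, rewards):
    """Log-likelihood of the observed choices, re-running the Q recursion."""
    Qs = np.full(2, 0.5)
    total = 0.0
    for a, r in zip(actions, rewards):
        logits = beta * Qs
        # log softmax probability of the chosen action
        total += logits[a] - (logits.max() + np.log(np.exp(logits - logits.max()).sum()))
        Qs[a] += alpha * (r - Qs[a])
    return total


actions, rewards = simulate_rl(alpha=0.3, beta=3.0, n_trials=200)
print(loglik_rl(0.3, 3.0, actions, rewards))
```

This sketch scores every trial, including the first one, whose probability is always 0.5 with these initial Q values. The Aesara code later in the thread skips that first trial (`Qs[:-1]` paired with `actions[1:]`), so the two differ only by the constant log 0.5.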

Have a nice day :slight_smile:


Thank you for sharing, this looks great :tada:


FYI: @Maria also shared her RL models on human data here: Modeling reinforcement learning of human participant using PyMC3


I’m loving this sharing mood :wink:


Is there a version for PyMC 4 with Aesara :pray:? I tried to use your code and I get an Input dimension mismatch error.

import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm

def update_Q(action, reward, Qs, alpha):
    """
    This function updates the Q table according to the RL update rule.
    It will be called by aesara.scan to do so recursively, given the observed data and the alpha parameter.
    This could have been replaced by the following lambda expression in the aesara.scan fn argument:
        fn=lambda action, reward, Qs, alpha: at.set_subtensor(Qs[action], Qs[action] + alpha * (reward - Qs[action]))
    """
    Qs = at.set_subtensor(Qs[action], Qs[action] + alpha * (reward - Qs[action]))
    return Qs

def theano_llik_td(alpha, beta, actions, rewards):
    rewards = aesara.shared(np.asarray(rewards, dtype='int16'))
    actions = aesara.shared(np.asarray(actions, dtype='int16'))

    # Compute the Qs values
    Qs = 0.5 * at.ones((2,), dtype='float64')
    Qs, updates = aesara.scan(
        fn=update_Q,
        sequences=[actions, rewards],
        outputs_info=[Qs],
        non_sequences=[alpha])

    # Apply the softmax transformation
    Qs_ = Qs[:-1] * beta
    log_prob_actions = Qs_ - pm.math.logsumexp(Qs_, axis=1)

    # Calculate the log likelihood of the observed actions
    log_prob_actions = log_prob_actions[at.arange(actions.shape[0]-1), actions[1:]]
    return at.sum(log_prob_actions)  # PyMC makes it negative by default

# actions and rewards are the observed (e.g. simulated) data arrays
with pm.Model() as m:
    alpha = pm.Beta('alpha', 1, 1)
    beta = pm.HalfNormal('beta', 10)
    like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
    tr = pm.sample()
ValueError                                Traceback (most recent call last)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    960 try:
    961     outputs = (
--> 962         self.fn()
    963         if output_subset is None
    964         else self.fn(output_subset=output_subset)
    965     )
    966 except Exception:

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Input In [12], in <cell line: 7>()
      9 beta = pm.HalfNormal('beta', 10)
     10 like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
---> 11 tr = pm.sample()

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    531         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    532"Auto-assigning NUTS sampler...")
--> 533     initial_points, step = init_nuts(
    534         init=init,
    535         chains=chains,
    536         n_init=n_init,
    537         model=model,
    538         random_seed=random_seed_list,
    539         progressbar=progressbar,
    540         jitter_max_retries=jitter_max_retries,
    541         tune=tune,
    542         initvals=initvals,
    543         **kwargs,
    544     )
    546 if initial_points is None:
    547     # Time to draw/evaluate numeric start points for each chain.
    548     ipfns = make_initial_point_fns_per_chain(
    549         model=model,
    550         overrides=initvals,
    551         jitter_rvs=filter_rvs_to_jitter(step),
    552         chains=chains,
    553     )

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2480"Initializing NUTS using {init}...")
   2482 cb = [
   2483     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
   2484     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
   2485 ]
-> 2487 initial_points = _init_jitter(
   2488     model,
   2489     initvals,
   2490     seeds=random_seed_list,
   2491     jitter="jitter" in init,
   2492     jitter_max_retries=jitter_max_retries,
   2493 )
   2495 apoints = [ for point in initial_points]
   2496 apoints_data = [ for apoint in apoints]

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2379 if i < jitter_max_retries:
   2380     try:
-> 2381         model.check_start_vals(point)
   2382     except SamplingError:
   2383         # Retry with a new seed
   2384         seed = rng.randint(2**30, dtype=np.int64)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in Model.check_start_vals(self, start)
   1716     valid_keys = ", ".join(self.named_vars.keys())
   1717     raise KeyError(
   1718         "Some start parameters do not appear in the model!\n"
   1719         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1720     )
-> 1722 initial_eval = self.point_logps(point=elem)
   1724 if not all(np.isfinite(v) for v in initial_eval.values()):
   1725     raise SamplingError(
   1726         "Initial evaluation of model at starting point failed!\n"
   1727         f"Starting values:\n{elem}\n\n"
   1728         f"Initial evaluation results:\n{initial_eval}"
   1729     )

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in Model.point_logps(self, point, round_vals)
   1757 factors = self.basic_RVs + self.potentials
   1758 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1759 return {
   1760 np.round(np.asarray(factor_logp), round_vals)
   1761     for factor, factor_logp in zip(
   1762         factors,
-> 1763         self.compile_fn(factor_logps_fn)(point),
   1764     )
   1765 }

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/, in PointFunc.__call__(self, state)
   1861 def __call__(self, state):
-> 1862     return self.f(**state)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    973     if hasattr(self.fn, "thunks"):
    974         thunk = self.fn.thunks[self.fn.position_of_error]
--> 975     raise_with_op(
    976         self.maker.fgraph,
    977         node=self.fn.nodes[self.fn.position_of_error],
    978         thunk=thunk,
    979         storage_map=getattr(self.fn, "storage_map", None),
    980     )
    981 else:
    982     # old-style linkers raise their own exceptions
    983     raise

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     warnings.warn(
    534         f"{exc_type} error does not allow us to add an extra error message"
    535     )
    536     # Some exception need extra parameter in inputs. So forget the
    537     # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/, in Function.__call__(self, *args, **kwargs)
    959 t0_fn = time.time()
    960 try:
    961     outputs = (
--> 962         self.fn()
    963         if output_subset is None
    964         else self.fn(output_subset=output_subset)
    965     )
    966 except Exception:
    967     restore_defaults()

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.
Apply node that caused the error: Elemwise{Sub}[(0, 0)](Elemwise{Mul}[(0, 0)].0, Elemwise{Composite{(i0 + log(i1))}}[(0, 0)].0)
Toposort index: 68
Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (1, None))]
Inputs shapes: [(99, 2), (1, 99)]
Inputs strides: [(16, 8), (792, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[AdvancedSubtensor(Elemwise{Sub}[(0, 0)].0, ARange{dtype='int64'}.0, Subtensor{int64::}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

I updated the notebook to use Aesara and the new version of PyMC: stats/RL_PyMC.ipynb at master · ricardoV94/stats · GitHub

The shape problem is indeed new, and it is solved by passing keepdims=True to the logsumexp function.
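The shapes in the traceback can be reproduced in plain NumPy, which may make the fix easier to see. `Qs_` has shape `(n_trials - 1, 2)`, here `(99, 2)`. Reducing over `axis=1` without keeping the dimension gives shape `(99,)`, which broadcasting treats as a row `(1, 99)` rather than a column. That matches the `(1, 99)` input in the error. The `logsumexp` helper below is an illustrative NumPy stand-in written for this example, not PyMC's implementation.

```python
import numpy as np


def logsumexp(x, axis=None, keepdims=False):
    """Numerically stable log(sum(exp(x))) along `axis` (NumPy stand-in)."""
    x_max = np.max(x, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return out if keepdims else np.squeeze(out, axis=axis)


Qs_ = np.random.default_rng(0).normal(size=(99, 2))  # stands in for Qs[:-1] * beta

# Without keepdims the normalizer has shape (99,) and cannot be subtracted row-wise
# from a (99, 2) array: NumPy raises a broadcasting ValueError, analogous to
# Aesara's "Input dimension mismatch".
try:
    Qs_ - logsumexp(Qs_, axis=1)
except ValueError as err:
    print("shape error:", err)

# With keepdims=True the normalizer has shape (99, 1) and broadcasts across both actions
log_prob_actions = Qs_ - logsumexp(Qs_, axis=1, keepdims=True)
print(log_prob_actions.shape)  # (99, 2)
```

In the Aesara model the corresponding change is to pass `keepdims=True` to the `logsumexp` call on the softmax line.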


Thank You!!!

Thanks so much for the amazing notebook! I have a quick question. If I have several independent blocks of learning data that I want to fit, is there a way to vectorize this in Aesara as well? Should I use a nested aesara.scan? Essentially, I want the code below but without the for loop, to make it faster:
for block_idx in blocks:
    Qs = 0.5 * at.ones((2,), dtype='float64')
    Qs, updates = aesara.scan(
        fn=update_Q,
        sequences=[actions[block_idx], rewards[block_idx]],
        outputs_info=[Qs],
        non_sequences=[alpha])

If they are all the same length you should be able to run in a single scan (might need to adjust indexing and axis operations). Otherwise if they are similar in length you could pad the short ones with zeros or nan and index afterwards.
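One way to apply the padding idea, sketched in NumPy under the assumption that each block is a 1-D integer array of actions (rewards would be handled the same way): pad every block to the longest length and keep a boolean mask marking the real trials, so padded entries can be dropped afterwards. The name `pad_blocks` and the fill value 0 are illustrative choices. An integer array cannot hold NaN, and 0 is a valid action index, so padded trials can pass through the Q update without breaking the indexing.

```python
import numpy as np


def pad_blocks(blocks, fill_value=0):
    """Stack 1-D arrays of unequal length into a (n_blocks, max_len) array plus a mask."""
    max_len = max(len(b) for b in blocks)
    padded = np.full((len(blocks), max_len), fill_value, dtype=np.asarray(blocks[0]).dtype)
    mask = np.zeros((len(blocks), max_len), dtype=bool)
    for i, b in enumerate(blocks):
        padded[i, : len(b)] = b
        mask[i, : len(b)] = True  # True where the trial really exists
    return padded, mask


actions_padded, mask = pad_blocks([np.array([0, 1, 1]), np.array([1, 0]), np.array([0])])
print(actions_padded)
print(mask)
```

Padding is appended at the end of each block, so padded trials only change Q values after the last real trial. Multiplying the per-trial log-probabilities by `mask` (or indexing with it) before summing therefore leaves the likelihood of the real trials unchanged.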

But there is nothing wrong with using a double scan, and that might be the best solution for your case. It’s also probably the easiest way to get started. If you find speed issues you can try to get rid of it later.

Thanks for the advice! Another question: is there a way to deal with NaN values in my RL likelihood function? What I used to do was first retrieve the indices of the data (as a NumPy matrix) where the value is not NaN, via idx = np.argwhere(~pd.isnull(rewards)).flatten(), and then update the Q matrix using only these indices. However, I can't find the Aesara equivalent of pd.isnull.

There might be something like an at.isnan? Usually if there’s a numpy function there should be an Aesara equivalent…
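Instead of selecting indices, another option that fits a `scan`-style recursion is to keep missing trials in place and turn their update into a no-op. The NumPy sketch below does this with `np.isnan` and `np.where`. The Aesara analogues would presumably be `at.isnan` and `at.switch`, but that translation is not shown here. The sketch assumes that a NaN reward marks a trial to be skipped entirely (no Q update, no likelihood contribution), and that actions on such trials still hold a valid placeholder integer. The name `loglik_with_missing` is hypothetical.

```python
import numpy as np


def loglik_with_missing(alpha, beta, actions, rewards):
    """Softmax Q-learning log-likelihood that skips trials whose reward is NaN."""
    actions = np.asarray(actions, dtype=int)
    rewards = np.asarray(rewards, dtype=float)
    Qs = np.full(2, 0.5)
    total = 0.0
    for a, r in zip(actions, rewards):
        missing = np.isnan(r)
        logits = beta * Qs
        log_p = logits[a] - (logits.max() + np.log(np.exp(logits - logits.max()).sum()))
        # Missing trials contribute nothing to the log-likelihood...
        total += np.where(missing, 0.0, log_p)
        # ...and leave the Q values unchanged; NaN is replaced so it never enters the arithmetic
        updated = Qs.copy()
        updated[a] = Qs[a] + alpha * (np.nan_to_num(r) - Qs[a])
        Qs = np.where(missing, Qs, updated)
    return float(total)


print(loglik_with_missing(0.3, 2.0, [0, 1, 0, 1], [1.0, np.nan, 1.0, 0.0]))
```

Because every trial goes through the same computation, this pattern does not need data-dependent indexing inside the recursion.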


It seems that Aesara doesn't have argwhere. I'm trying to find the indices of an element in my array, which in NumPy I would do with np.argwhere, but there is no such function in Aesara. The workaround I came up with is applying at.nonzero to a boolean array marking where the array equals the element. Posting this in case anyone else runs into the same issue.
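The workaround can be checked in NumPy. `np.nonzero` applied to a boolean mask returns the same positions as `np.argwhere`, but as a tuple of index arrays (one per dimension) rather than a 2-D array. `at.nonzero` is the Aesara counterpart mentioned in the post. The example itself only uses NumPy, and the array values are made up.

```python
import numpy as np

arr = np.array([3, 1, 4, 1, 5, 9, 2, 6])
target = 1

# argwhere returns an (n_matches, ndim) array; flatten it for a 1-D input
idx_argwhere = np.argwhere(arr == target).flatten()
# nonzero returns one index array per dimension; take the first for a 1-D input
idx_nonzero = np.nonzero(arr == target)[0]

print(idx_argwhere, idx_nonzero)  # [1 3] [1 3]
```

When the array is observed data rather than a model variable, these indices can usually be computed once in NumPy before building the model and passed in as constants.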


Hi! Thanks to your tutorial, I managed to fit my RL model. However, when I try to involve all subjects (97 of them), it takes forever to even compile the model. Here is my model:

like_vec = lambda data, alpha, beta, forget: \
    at.sum(at.as_tensor_variable([agent.aesara_llik_td(data[i],alpha[i], beta[i], forget[i]) for i in range(n_subj)])) 
with pm.Model() as m:
    alpha = pm.Beta('alpha', alpha=1, beta=1, shape=n_subj)
    beta = pm.Gamma('beta', alpha=1, beta=1, shape=n_subj)
    forget = pm.Beta('forget', alpha=1, beta=1, shape=n_subj)

    like = pm.Potential('like', like_vec(data, alpha, beta, forget))
    tr = pm.sample(draws=3000,tune = 1500, chains=4, cores=4)

If I set n_subj to 10, it takes about 50 s to compile the model, after which messages are printed and a progress bar appears. However, when I increase n_subj to 94 (the number of subjects I have), it has been running for over two days and I still don't see any message or progress bar! What could be causing this? With 2 subjects, fitting the entire model took only 10 minutes, so it shouldn't take this long. Did I do anything wrong?

You could try to vectorize the outer loop (across subjects), either directly if the sequences have the same length or with an outer scan if they don’t. That could produce a more efficient graph than what you have with 100 independent scans.
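Here is a NumPy sketch of what "vectorizing the outer loop" could look like. It assumes every subject has the same number of trials; with unequal lengths you would pad and mask, as suggested earlier in the thread. The Q table gets shape `(n_subj, 2)`, `alpha` and `beta` become length-`n_subj` vectors, and a single loop over trials updates all subjects at once with fancy indexing. In the Aesara graph that single loop would correspond to one `scan`, instead of one scan per subject. The function name is illustrative.

```python
import numpy as np


def loglik_all_subjects(alpha, beta, actions, rewards):
    """Per-subject log-likelihoods for (n_subj, n_trials) data in one pass over trials.

    alpha, beta: arrays of shape (n_subj,); actions: ints in {0, 1}; rewards: 0/1.
    """
    actions = np.asarray(actions, dtype=int)
    rewards = np.asarray(rewards, dtype=float)
    n_subj, n_trials = actions.shape
    subj = np.arange(n_subj)
    Qs = np.full((n_subj, 2), 0.5)
    ll = np.zeros(n_subj)
    for t in range(n_trials):
        a = actions[:, t]
        logits = beta[:, None] * Qs  # (n_subj, 2)
        m = logits.max(axis=1, keepdims=True)
        log_norm = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
        ll += logits[subj, a] - log_norm
        # Update only the chosen arm of each subject
        Qs[subj, a] += alpha * (rewards[:, t] - Qs[subj, a])
    return ll
```

Because subjects are independent, the total log-likelihood is `ll.sum()`, and the graph size no longer grows with the number of subjects.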


That might help,
It’s on aversive learning but should be relatively easy to adjust. And also based on @ricardoV94 's work!


A follow-up question: do you have recommendations on how to compare RL models using WAIC? I tried az.waic(tr) on the trace produced by fitting my RL model, but I get:

TypeError: log likelihood not found in inference data object

I searched some of the prior discussions, and it seems that replacing pm.Potential with pm.DensityDist solves this issue. Is this the recommended way? I tried pm.DensityDist and then started to get memory issues, which automatically restart my computer when model fitting is about 63.59% done. I never had this issue with pm.Potential. However, when I used pm.Potential and fit 3 models simultaneously in 3 separate notebooks, I did get the same problem. Maybe I didn't write the code in the most memory-efficient way, and pm.DensityDist simply pushed it over the edge?

Can you show the code with the DensityDist? It should be possible to do it that way. You just have to make sure you don’t collapse multiple observations as model comparison requires the elementwise loglikelihood.

Potentials can’t be used when doing model comparison, because PyMC doesn’t know if the Potential is supposed to be added as a prior or likelihood term (and only the latter matters).
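To see why the elementwise log-likelihood matters: WAIC is computed from a matrix of log-likelihood values with one row per posterior draw and one column per observation (for example, per trial). A single summed value per draw, which is all a Potential provides, is not enough. The sketch below computes WAIC from such a matrix using the standard definitions: lppd, the log of the posterior-mean likelihood per observation, minus p_waic, the sample variance of each observation's log-likelihood across draws. It only illustrates the quantity; `az.waic` remains the practical tool. The toy matrix is made-up data.

```python
import numpy as np


def waic(log_lik):
    """WAIC from a (n_draws, n_obs) matrix of pointwise log-likelihoods.

    Returns (elpd_waic, p_waic) on the log scale; multiply elpd_waic by -2
    for the deviance scale.
    """
    log_lik = np.asarray(log_lik, dtype=float)
    n_draws = log_lik.shape[0]
    # log pointwise predictive density, computed stably per observation
    m = log_lik.max(axis=0)
    lppd_i = m + np.log(np.exp(log_lik - m).sum(axis=0)) - np.log(n_draws)
    # effective number of parameters: variance of each observation's log-likelihood
    p_waic_i = log_lik.var(axis=0, ddof=1)
    return float(np.sum(lppd_i - p_waic_i)), float(np.sum(p_waic_i))


rng = np.random.default_rng(0)
toy_log_lik = rng.normal(-0.7, 0.05, size=(4000, 100))  # made-up: 4000 draws, 100 trials
elpd, p_waic = waic(toy_log_lik)

# Summing over observations first (one number per draw) discards the per-observation
# variance that p_waic is built from
summed_per_draw = toy_log_lik.sum(axis=1)
print(summed_per_draw.shape)  # (4000,)
```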

We discussed with @junpenglao and @aloctavodia adding such distinction to Potentials.


It also seems like you may be using an older version of PyMC/Aesara. Try updating and see if it helps, as some scan bugs have been fixed recently.


I updated pymc to 4.1.3, but the memory problem persisted. I think this is simply because DensityDist needs more memory, as it has to store the likelihood information. I worked around it by reducing the number of cores used, so sampling takes longer but uses less memory. I found another issue, though, which could be a bug in the new version. When I sample with 2 or more cores:
tr = pm.sample(draws=3000,tune = 1000, chains=4, cores=2)
I get the message:
[29/16000 05:18<48:45:38 Sampling 4 chains, 0 divergences]
However, if I only change cores = 1, I get instead:
[11/4000 02:18<13:59:05 Sampling chain 0, 0 divergences]
which doesn’t look right.

No, that shouldn’t matter.

Regarding the time estimates, I wouldn't read too much into them early in sampling, especially when sampling is this slow. The differences might just be due to different starting points.
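A quick arithmetic check suggests the two quoted progress bars describe the same amount of work. With `cores=2` the bar counts all chains together: 4 chains × (3000 draws + 1000 tuning steps) = 16000. With `cores=1` the bar reads "Sampling chain 0" out of 4000, i.e. one chain at a time. The single-core time estimate therefore has to be multiplied by the number of chains before comparing. The helper `to_seconds` is written just for this check.

```python
def to_seconds(hms):
    """Convert an 'HH:MM:SS' string to seconds."""
    h, m, s = (int(part) for part in hms.split(":"))
    return 3600 * h + 60 * m + s


chains, draws, tune = 4, 3000, 1000
total_iters = chains * (draws + tune)  # iterations shown by the multi-core bar
per_chain_iters = draws + tune         # iterations shown by the single-core bar

eta_two_cores = to_seconds("48:45:38")
eta_one_core_all_chains = chains * to_seconds("13:59:05")  # chains run one after another

print(total_iters, per_chain_iters)  # 16000 4000
print(round(eta_two_cores / 3600, 1), round(eta_one_core_all_chains / 3600, 1))  # 48.8 55.9
```

Both figures are early extrapolations, so this only shows that the two runs are in the same range (roughly 49 vs 56 hours), not that one is reliably faster.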