Trying to speed up a model with a custom likelihood

Hi folks,

I am modelling data from an experiment in which participants learn a simple language over 200 trials. In each trial, the participant sees four images (scenes) and a sentence in an unknown language that describes one of them. The sentence always has 3 words and could mean e.g. “Blue hits circle” (the scene where a blue object with an arm hits a circle). There are 4 objects and 3 actions in total. After seeing the four scenes and the sentence, the participant picks a scene and then gets feedback about which scene the sentence actually referred to. Over the experiment, participants are meant to learn the meaning of each word as well as the word order of the language (e.g. subject-verb-object).
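To make the trial structure concrete, here is a minimal sketch of how one trial could be encoded as integer arrays. The encoding is an assumption for illustration only: the helper `make_trial`, the meaning indices (0-3 for objects, 4-6 for actions) and the one-word-per-meaning mapping are hypothetical, not the data format used in the gist.

```python
import numpy as np

N_OBJECTS, N_ACTIONS, N_SCENES = 4, 3, 4


def make_trial(rng, word_order=(0, 1, 2)):
    """Hypothetical encoding of one trial (not the gist's data format).

    Each scene is a (subject, action, object) triple of meaning indices:
    objects are 0..3 and actions are 4..6. Words are assumed to map
    one-to-one onto meanings, so word w refers to meaning w.
    word_order lists which scene role fills each sentence slot,
    e.g. (0, 1, 2) is subject-verb-object, (0, 2, 1) subject-object-verb.
    """
    scenes = np.empty((N_SCENES, 3), dtype=int)
    for s in range(N_SCENES):
        # Assumption: the agent and the patient are different objects.
        subj, obj = rng.choice(N_OBJECTS, size=2, replace=False)
        action = N_OBJECTS + rng.integers(N_ACTIONS)
        scenes[s] = (subj, action, obj)
    target = int(rng.integers(N_SCENES))  # scene the sentence describes
    sentence = scenes[target, list(word_order)]  # 3 word indices, in sentence order
    return scenes, sentence, target


rng = np.random.default_rng(0)
scenes, sentence, target = make_trial(rng, word_order=(0, 2, 1))
# Over 200 trials, one participant's data would then be stacked arrays of
# shape (200, 4, 3) for scenes and (200, 3) for sentences, plus the chosen
# scene and the feedback (target) scene for each trial.
```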

I am trying to model the actual learning process of each participant hierarchically. As you can imagine, the likelihood function is quite complex, with a big scan over trials that has a set_subtensor inside it. Here’s a gist with the model (excuse the references to theano). The code in the gist runs parameter recovery from a single prior sample on a dataset much smaller than the final one, which has 200 trials and more than 200 participants. At the moment, fitting the full dataset takes more than 120 hours, which unfortunately is the hard limit on the server I am using. Here’s what I have tried:

  • Rewriting the scan as a loop and running it with sample_blackjax_nuts on the GPU. This didn’t seem to make it any faster (in fact it made it slower). I’m no JAX expert though, so I might have missed something.
  • Stopping the run and restarting it in a different server job. This is difficult for reasons discussed recently in this post.
  • Running variational inference instead. Unfortunately, parameter recovery with VI was pretty poor (in contrast to some tests with NUTS).
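Since the gist is not reproduced here, the following minimal NumPy sketch illustrates the general shape of a scan step of this kind. It assumes a hypothetical delta-rule learner, not the model in the gist: `W` is a word-by-meaning association matrix, `beta` is a softmax inverse temperature, `lr` is a learning rate, and `loglik_loop` and `N_MEANINGS` are illustrative names. The in-place row write at the end of each trial is the operation that `set_subtensor` performs inside the scan.

```python
import numpy as np

N_MEANINGS = 7  # hypothetical: 4 objects + 3 actions, one word per meaning


def loglik_loop(sentences, scenes, targets, choices, lr, beta, word_order=(0, 1, 2)):
    """Log-likelihood of ONE participant's choices under a hypothetical learner.

    sentences: (T, 3) word indices; scenes: (T, 4, 3) meaning indices per
    role (subject, action, object); targets: (T,) scene the sentence
    referred to (the feedback); choices: (T,) scene the participant picked.
    """
    W = np.full((N_MEANINGS, N_MEANINGS), 1.0 / N_MEANINGS)  # word x meaning
    total = 0.0
    for t in range(len(targets)):
        # Meaning expected in each sentence slot, for each of the 4 scenes.
        roles = scenes[t][:, list(word_order)]  # (4, 3)
        # Scene score: summed association between each word and its slot's meaning.
        scores = W[sentences[t], roles].sum(axis=1)  # (4,)
        logits = beta * scores
        logp = logits - logits.max()
        logp -= np.log(np.exp(logp).sum())  # log-softmax over the 4 scenes
        total += logp[choices[t]]
        # Feedback: move each word's row toward the meaning it had in the target scene.
        for w, m in zip(sentences[t], roles[targets[t]]):
            new_row = W[w] + lr * (np.eye(N_MEANINGS)[m] - W[w])
            W[w] = new_row  # in PyTensor: W = set_subtensor(W[w], new_row)
    return total
```

In a PyMC model this per-trial body would be the scan step, called once per participant unless the participants are batched.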

I was wondering if there is any way to simplify the likelihood function, or even get rid of the scan. Any help would be much appreciated. Thank you!
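One possible direction, sketched under the same hypothetical delta-rule learner (not the gist's model, so any speedup on the real likelihood is untested): batch over participants so that the only sequential loop is the one over trials, and replace the indexed row write (the set_subtensor) with a dense one-hot update. Inside a PyMC model, the `for t` loop would correspond to a single `pytensor.scan` over the 200 trials carrying a (participants, words, meanings) state. The one-hot update is plain elementwise arithmetic, so it should avoid indexed writes in both the forward pass and the gradient. `loglik_batched` and `N_MEANINGS` are illustrative names.

```python
import numpy as np

N_MEANINGS = 7  # hypothetical: 4 objects + 3 actions, one word per meaning


def loglik_batched(sentences, scenes, targets, choices, lr, beta, word_order=(0, 1, 2)):
    """Per-participant log-likelihoods for a hypothetical delta-rule learner.

    sentences: (P, T, 3); scenes: (P, T, 4, 3); targets, choices: (P, T);
    lr, beta: (P,) arrays of per-participant parameters
    (e.g. drawn from a hierarchical prior). Returns an array of shape (P,).
    """
    P, T = targets.shape
    eye = np.eye(N_MEANINGS)
    W = np.full((P, N_MEANINGS, N_MEANINGS), 1.0 / N_MEANINGS)
    p_idx = np.arange(P)
    total = np.zeros(P)
    for t in range(T):  # the only sequential loop: one scan step per trial
        word_oh = eye[sentences[:, t]]  # (P, 3, M): one-hot word in each slot
        roles = scenes[:, t][:, :, list(word_order)]  # (P, 4, 3) meaning per slot
        role_oh = eye[roles]  # (P, 4, 3, M)
        # scores[p, s] = sum_j W[p, word_j, meaning_sj], computed without indexing W
        scores = np.einsum("pjw,pwm,psjm->ps", word_oh, W, role_oh)
        logits = beta[:, None] * scores
        logp = logits - logits.max(axis=1, keepdims=True)
        logp -= np.log(np.exp(logp).sum(axis=1, keepdims=True))  # log-softmax
        total += logp[p_idx, choices[:, t]]
        target_oh = role_oh[p_idx, targets[:, t]]  # (P, 3, M) meanings in target scene
        for j in range(3):  # slots updated in order, matching a row-by-row loop
            w = word_oh[:, j]  # (P, M) one-hot word in slot j
            row = np.einsum("pw,pwm->pm", w, W)  # current row of that word
            delta = lr[:, None] * (target_oh[:, j] - row)
            # Adds delta to the selected row only; this replaces set_subtensor.
            W = W + w[:, :, None] * delta[:, None, :]
    return total
```

The dense update touches all M x M entries per slot instead of one row, which is cheap here only because M is small (7 in this sketch). If the gist already batches over participants, the remaining gain, if any, would come from dropping the indexed writes.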