Scan always loops over the first dimension of whatever you pass to the `sequences` argument, so you could permute the axes to put the “batch” dimension first. Another strategy would be to pass a `pt.arange` as the sequence argument and the tensors as `non_sequences`, then index the tensors in the inner function. Here’s an example of the 2nd strategy, and here’s an example of the 1st.