Using PyTensor.scan to loop over multidimensional array

Scan always loops over the first dimension of whatever is passed to the sequences argument, so you could permute the axes to put the “batch” dimension first. Another strategy would be to pass a pt.arange as the sequence argument and the tensors as non_sequences, then index the tensors inside the inner function. Here’s an example of the 2nd strategy, and here’s an example of the 1st.
