The input signature for `scan` is somewhat complex; it is explained in the `scan` docstring:
fn
`fn` is a function that describes the operations involved in one
step of `scan`. `fn` should construct variables describing the
output of one iteration step. It should expect as input
`Variable`\s representing all the slices of the input sequences
and previous values of the outputs, as well as all other arguments
given to scan as `non_sequences`. The order in which scan passes
these variables to `fn` is the following:
* all time slices of the first sequence
* all time slices of the second sequence
* ...
* all time slices of the last sequence
* all past slices of the first output
* all past slices of the second output
* ...
* all past slices of the last output
* all other arguments (the list given as `non_sequences` to
`scan`)
The order of the sequences is the same as the one in the list
`sequences` given to `scan`. The order of the outputs is the same
as the order of `outputs_info`. For any sequence or output the
order of the time slices is the same as the one in which they have
been given as taps. For example, if one writes the following:
.. code-block:: python

    scan(
        fn,
        sequences=[dict(input=Sequence1, taps=[-3, 2, -1]),
                   Sequence2,
                   dict(input=Sequence3, taps=3)],
        outputs_info=[dict(initial=Output1, taps=[-3, -5]),
                      dict(initial=Output2, taps=None),
                      Output3],
        non_sequences=[Argument1, Argument2],
    )

`fn` should expect the following arguments in this given order:
#. ``sequence1[t-3]``
#. ``sequence1[t+2]``
#. ``sequence1[t-1]``
#. ``sequence2[t]``
#. ``sequence3[t+3]``
#. ``output1[t-3]``
#. ``output1[t-5]``
#. ``output3[t-1]``
#. ``argument1``
#. ``argument2``
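
To make the ordering rule concrete, here is a small pure-Python sketch that rebuilds the list above from a description of the `scan` call. It does not use PyTensor, and `expected_fn_args` and `_slice` are hypothetical helpers written only for this illustration. It assumes every entry is given as a `(name, taps)` pair with the taps spelled out, so a plain sequence is written with taps `[0]`, a plain output with taps `[-1]`, and an output with `taps=None` is not fed back at all.

```python
def _slice(name, tap):
    """Format one time slice, e.g. ("sequence1", -3) -> "sequence1[t-3]"."""
    if tap == 0:
        return f"{name}[t]"
    return f"{name}[t{tap:+d}]"


def expected_fn_args(sequences, outputs_info, non_sequences):
    """Hypothetical helper: list the arguments `fn` receives, in order."""

    def as_list(taps):
        # A single integer tap is treated like a one-element list.
        return list(taps) if isinstance(taps, (list, tuple)) else [taps]

    args = []
    # 1. All time slices of each sequence, in the order the taps were given.
    for name, taps in sequences:
        args += [_slice(name, k) for k in as_list(taps)]
    # 2. All past slices of each recurrent output; taps=None means the
    #    output is not passed back to `fn`.
    for name, taps in outputs_info:
        if taps is None:
            continue
        args += [_slice(name, k) for k in as_list(taps)]
    # 3. The non-sequences, unchanged.
    args += list(non_sequences)
    return args


order = expected_fn_args(
    sequences=[("sequence1", [-3, 2, -1]), ("sequence2", [0]), ("sequence3", 3)],
    outputs_info=[("output1", [-3, -5]), ("output2", None), ("output3", [-1])],
    non_sequences=["argument1", "argument2"],
)
for position, arg in enumerate(order, start=1):
    print(position, arg)
```

Running it prints the same ten arguments in the same order as the list in the docstring; `output2` does not appear because its taps are `None`.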
So in your case, `a`, `b`, and `c` are recursive computations, so they go in `outputs_info`. If `d` and `e` are sequences, they go into `sequences`. The signature of your inner function will be `inner(d, e, a, b, c)` because, according to the documentation, the general order of inputs is sequences, then outputs_info, then non-sequences. The order within each group is the order in which you pass them to `scan`, so I’m assuming you have `pytensor.scan(inner, sequences=[d, e], outputs_info=[a, b, c])`. You will return only `a`, `b`, and `c` from your inner function.