Scan multi-output function with state variable a subset of outputs

The input signature for scan is a bit complex; it is explained in the docstring for scan:

fn
    `fn` is a function that describes the operations involved in one
    step of `scan`. `fn` should construct variables describing the
    output of one iteration step. It should expect as input
    `Variable`\s representing all the slices of the input sequences
    and previous values of the outputs, as well as all other arguments
    given to scan as `non_sequences`. The order in which scan passes
    these variables to `fn` is the following:

    * all time slices of the first sequence
    * all time slices of the second sequence
    * ...
    * all time slices of the last sequence
    * all past slices of the first output
    * all past slices of the second output
    * ...
    * all past slices of the last output
    * all other arguments (the list given as `non_sequences` to
        `scan`)

    The order of the sequences is the same as the one in the list
    `sequences` given to `scan`. The order of the outputs is the same
    as the order of `outputs_info`. For any sequence or output the
    order of the time slices is the same as the one in which they have
    been given as taps. For example, if one writes the following:

    .. code-block:: python

        scan(fn,
             sequences=[dict(input=Sequence1, taps=[-3, 2, -1]),
                        Sequence2,
                        dict(input=Sequence3, taps=3)],
             outputs_info=[dict(initial=Output1, taps=[-3, -5]),
                           dict(initial=Output2, taps=None),
                           Output3],
             non_sequences=[Argument1, Argument2])

    `fn` should expect the following arguments in this given order:

    #. ``sequence1[t-3]``
    #. ``sequence1[t+2]``
    #. ``sequence1[t-1]``
    #. ``sequence2[t]``
    #. ``sequence3[t+3]``
    #. ``output1[t-3]``
    #. ``output1[t-5]``
    #. ``output3[t-1]``
    #. ``argument1``
    #. ``argument2``
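
The ordering rule quoted above can be sketched in plain Python. This is only an illustration, not PyTensor code: `scan_argument_order` is a hypothetical helper, and the default taps (0 for a plain sequence, -1 for a plain output) are written out explicitly rather than inferred.

```python
def scan_argument_order(sequences, outputs_info, non_sequences):
    """Return labels for the arguments `fn` would receive, in order.

    sequences:     list of (name, taps) pairs, taps being a list of ints
    outputs_info:  list of (name, taps) pairs; taps=None means the output
                   is not fed back into `fn`
    non_sequences: list of names
    """
    def label(name, tap):
        # tap 0 is the current time step; others are offsets from t
        return f"{name}[t]" if tap == 0 else f"{name}[t{tap:+d}]"

    args = []
    # 1. every requested slice of every sequence, in the order given
    for name, taps in sequences:
        args.extend(label(name, tap) for tap in taps)
    # 2. every requested past slice of every recurrent output
    for name, taps in outputs_info:
        if taps is None:
            continue  # non-recurrent output: nothing is passed to fn
        args.extend(label(name, tap) for tap in taps)
    # 3. the non-sequences, unchanged
    args.extend(non_sequences)
    return args


order = scan_argument_order(
    sequences=[("sequence1", [-3, 2, -1]), ("sequence2", [0]), ("sequence3", [3])],
    outputs_info=[("output1", [-3, -5]), ("output2", None), ("output3", [-1])],
    non_sequences=["argument1", "argument2"],
)
for position, name in enumerate(order, start=1):
    print(position, name)
```

With the inputs from the docstring example, the printed list matches the ten arguments listed above; note that `output2` never appears because its taps are `None`.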

So in your case, a, b, and c are recursive computations, so they go in outputs_info. If d and e are sequences, they go in sequences. The signature of your inner function will then be inner(d, e, a, b, c): according to the documentation, the general order of inputs is sequences, then outputs_info, then non_sequences, and the order within each group is the order in which you pass them to scan. I'm assuming you have pytensor.scan(inner, sequences=[d, e], outputs_info=[a, b, c]). Your inner function returns only a, b, c.
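
To make the calling convention concrete, here is a plain-Python stand-in for the loop that scan runs with default taps. It is a sketch, not PyTensor: `scan_like` is a hypothetical helper, and the update rules inside `inner` are made up, since I don't know what your actual computation is.

```python
def inner(d_t, e_t, a_prev, b_prev, c_prev):
    # Sequence slices come first, then the previous value of each
    # recurrent output. The update rules below are placeholders.
    a = a_prev + d_t
    b = b_prev * e_t
    c = c_prev + a + b
    # Return the new values in the same order as outputs_info
    return a, b, c


def scan_like(fn, sequences, outputs_info):
    """Plain-Python stand-in for scan with default taps (hypothetical)."""
    state = tuple(outputs_info)      # initial values of a, b, c
    history = []
    for slices in zip(*sequences):   # one time slice from each sequence
        state = tuple(fn(*slices, *state))
        history.append(state)
    return history


d = [1, 2, 3]
e = [2, 2, 2]
history = scan_like(inner, sequences=[d, e], outputs_info=[0, 1, 0])
print(history)
```

One difference to keep in mind: as far as I know, the real scan gives you one tensor per output, stacked along the time axis, rather than a list of per-step tuples like this sketch does.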
