Hi all,
In NumPyro, I can scan with a function that returns two tuples, where the first tuple holds all the inputs to the next iteration. For example, I can scan over a function like
    def fn(state, xs):
        # g and h are my own update and output functions
        a, b, c = state
        d, e = xs
        a, b, c = g(a, b, c, d, e)
        updated_state = (a, b, c)
        output = h(a, b, c, d, e)
        return updated_state, output
where `xs` is the tuple containing the iterables. Can I do something similar with PyTensor scan, i.e., have only a subset of the outputs used as inputs to the next iteration?
Thanks!
If the `xs` are exogenous variables of length `T` (the length of the scan), you can pass them in via the `sequences` argument. The state will be passed in via the `outputs_info` argument. See here for everything you ever wanted to know about scan.
Many thanks for the quick reply! The RNN seems like the relevant example. It seems that to get what I want, I need a function that also takes all output variables as inputs, even if it doesn't act on some of them? I think I can live with that.
It depends on the nature of the variables `d` and `e`. If they are static parameters that remain the same at every iteration, you can use `non_sequences`. If they are vectors of pre-computed parameters (one value per iteration), you should use `sequences`.
The `xs` are indeed sequences, but my point was that with PyMC's scan it seems I need `def fn(state, output, xs)`. Maybe I'm wrong, but I don't see a way to define `outputs_info` so that, if my function returns two tuples or lists, only the first of them is fed into the next iteration.
The input signature for scan is a bit complex and explained in the docstring for scan:
fn
`fn` is a function that describes the operations involved in one
step of `scan`. `fn` should construct variables describing the
output of one iteration step. It should expect as input
`Variable`\s representing all the slices of the input sequences
and previous values of the outputs, as well as all other arguments
given to scan as `non_sequences`. The order in which scan passes
these variables to `fn` is the following:
* all time slices of the first sequence
* all time slices of the second sequence
* ...
* all time slices of the last sequence
* all past slices of the first output
* all past slices of the second output
* ...
* all past slices of the last output
* all other arguments (the list given as `non_sequences` to
`scan`)
The order of the sequences is the same as the one in the list
`sequences` given to `scan`. The order of the outputs is the same
as the order of `outputs_info`. For any sequence or output the
order of the time slices is the same as the one in which they have
been given as taps. For example, if one writes the following:
.. code-block:: python

    scan(
        fn,
        sequences=[
            dict(input=Sequence1, taps=[-3, 2, -1]),
            Sequence2,
            dict(input=Sequence3, taps=3),
        ],
        outputs_info=[
            dict(initial=Output1, taps=[-3, -5]),
            dict(initial=Output2, taps=None),
            Output3,
        ],
        non_sequences=[Argument1, Argument2],
    )
`fn` should expect the following arguments in this given order:
#. ``sequence1[t-3]``
#. ``sequence1[t+2]``
#. ``sequence1[t-1]``
#. ``sequence2[t]``
#. ``sequence3[t+3]``
#. ``output1[t-3]``
#. ``output1[t-5]``
#. ``output3[t-1]``
#. ``argument1``
#. ``argument2``
So in your case, `a`, `b`, and `c` are recursive computations: they go in `outputs_info`. If `d` and `e` are sequences, they go into `sequences`. The signature of your inner function will be `inner(d, e, a, b, c)`, because according to the documentation the general order of inputs is sequences, then outputs_info, then non-sequences. (The order within each of those groups is the order in which you pass them to `scan`, so I'm assuming you call `pytensor.scan(inner, sequences=[d, e], outputs_info=[a, b, c])`.) You will return only `a, b, c` from your inner function.
OK, I know how it works now. I just need to add `None` to the `outputs_info` list for outputs that don't feed back into scan.
Thanks!