You need to sort your variables into three groups:
- sequences: stuff you want to iterate over
- outputs_info: stuff you want to compute on, then recursively feed back into the loop at every iteration
- non_sequences: stuff you want to carry along, unchanged, into every loop iteration
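I think you have no sequences; res and y0_vec will be outputs_info, and mu and A will be non_sequences. Because there are no sequences, scan can't infer how many iterations to run, so you have to pass n_steps yourself.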
After sorting, you write the inner function. You have to sort first because the inputs to the inner function have to be in a certain order: sequences, then outputs_info, then non_sequences. So if my classification is right, you’ll have an inner function like this:
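import pytensor
import pytensor.tensor as pt

def step(last_y0, last_res, A, mu):
    # previous values of the outputs_info come first, then the non_sequences
    new_y0 = A @ (pt.diag(last_y0) @ pt.exp(mu) / 2)
    new_res = last_res + new_y0
    return new_y0, new_res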
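To make the ordering rule concrete, here is a toy, pure-Python stand-in for scan. It is not PyTensor's implementation, just a sketch of the calling convention under simplifying assumptions: `toy_scan` is a hypothetical helper, every entry in outputs_info is treated as a recurrent output, and n_steps must be given when there are no sequences. At each step the inner function gets the current slice of each sequence, then the previous value of each output, then every non-sequence.

```python
def toy_scan(fn, sequences=(), outputs_info=(), non_sequences=(), n_steps=None):
    """Hypothetical pure-Python mimic of scan's argument order (not PyTensor code).

    Every entry in outputs_info is treated as a recurrent output, and fn must
    return one new value per entry. Without sequences, n_steps must be given.
    """
    if n_steps is None:
        # with sequences, the loop length can be inferred from them
        n_steps = min(len(s) for s in sequences)
    previous = list(outputs_info)
    history = [[] for _ in previous]
    for t in range(n_steps):
        current = [s[t] for s in sequences]
        # call order: sequence slices, then previous outputs, then non-sequences
        result = fn(*current, *previous, *non_sequences)
        if not isinstance(result, (tuple, list)):
            result = (result,)
        previous = list(result)
        # record each new output value, one list per output
        for h, r in zip(history, result):
            h.append(r)
    return history


def accumulate(x_t, last_total, scale):
    # sequence slice, then previous output, then non-sequence
    return last_total + scale * x_t

print(toy_scan(accumulate, sequences=[[1, 2, 3]], outputs_info=[0], non_sequences=[10]))
# [[10, 30, 60]]
```

With your step function there are no sequences, so each call is just `step(last_y0, last_res, A, mu)`, which is why the signature has to be in that order.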
Then you give this all to scan:
(y0_sequence, res_sequence), updates = pytensor.scan(
    step,
    outputs_info=[y0_vec, y0_vec],  # initial values for last_y0 and last_res, in that order
    non_sequences=[A, mu],
    n_steps=n_generations,
)
Warning: I didn’t test any of this code.
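One way to check the scan once it runs is to compare it against a plain NumPy loop that does the same computation. This is a sketch that assumes A is a square matrix and mu and y0_vec are 1-D arrays of the same length; `numpy_reference` is a hypothetical helper, not part of PyTensor. It uses the fact that `diag(v) @ w` is just the elementwise product `v * w`.

```python
import numpy as np


def numpy_reference(y0_vec, A, mu, n_generations):
    """Plain-loop equivalent of the scan call (hypothetical helper for checking)."""
    last_y0, last_res = y0_vec, y0_vec  # same initial values as outputs_info
    y0_steps, res_steps = [], []
    for _ in range(n_generations):
        # diag(last_y0) @ exp(mu) is the elementwise product last_y0 * exp(mu)
        new_y0 = A @ (last_y0 * np.exp(mu) / 2)
        new_res = last_res + new_y0
        y0_steps.append(new_y0)
        res_steps.append(new_res)
        last_y0, last_res = new_y0, new_res
    # like scan, stack the per-step outputs (initial values excluded)
    return np.stack(y0_steps), np.stack(res_steps)


y0s, ress = numpy_reference(np.array([1.0, 2.0]), np.eye(2), np.log(2.0) * np.ones(2), 3)
print(ress)  # rows close to [2, 4], [3, 6], [4, 8]
```

If y0_vec, A and mu are symbolic inputs, you could compile the scan outputs with `pytensor.function` and compare them to these arrays with `np.allclose`.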
Also, all of this is better explained in the scan docs, found here.