Declare priors using loop results in "bracket nesting level exceeded maximum"

You need to sort your variables into three groups:

  • sequences: stuff you want to iterate over
  • outputs_info: stuff you compute at each step and then feed back into the loop on the next iteration (you pass in its initial values)
  • non_sequences: stuff you want to carry along unchanged into every loop iteration
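To make the argument order concrete, here is a toy, pure-Python stand-in for scan. `toy_scan` is a hypothetical helper written for illustration, not PyTensor's implementation: real scan builds a symbolic graph, stacks the per-step outputs into tensors and also returns an updates dictionary. The only thing this sketch is meant to show is how the inner function gets called: current sequence elements first, then the fed-back outputs, then the non_sequences.

```python
def toy_scan(fn, sequences=(), outputs_info=(), non_sequences=(), n_steps=None):
    """Hypothetical pure-Python mimic of scan's calling convention.

    Calls fn once per step as fn(*sequence_elements, *previous_outputs, *non_sequences)
    and returns one list per output holding that output's value at every step
    (the initial values are not included). n_steps must be given when there are
    no sequences to take the length from.
    """
    if n_steps is None:
        n_steps = min(len(s) for s in sequences)
    carry = list(outputs_info)
    history = [[] for _ in carry]
    for t in range(n_steps):
        # Current element of each sequence, then the recurrent values, then constants.
        seq_t = [s[t] for s in sequences]
        out = fn(*seq_t, *carry, *non_sequences)
        if not isinstance(out, tuple):
            out = (out,)
        carry = list(out)  # fed back into the next call
        for h, value in zip(history, carry):
            h.append(value)
    return history


def step(x_t, last_total, scale):
    # x_t: sequence element; last_total: fed back each step; scale: constant
    return last_total + scale * x_t


totals = toy_scan(step, sequences=[[1, 2, 3]], outputs_info=[0], non_sequences=[10])
print(totals)  # [[10, 30, 60]]
```

With no sequences, as in your case, the inner function simply receives the fed-back outputs followed by the non_sequences.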

I think you have no sequences; res and y0_vec will be outputs_info, and mu and A will be non_sequences.

After sorting, you write the inner function. You have to sort first because the inner function must take its inputs in a fixed order: sequences, then outputs_info, then non_sequences. So if my classification is right, your inner function will look like this:

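import pytensor
import pytensor.tensor as pt

def step(last_y0, last_res, A, mu):
    # Order: outputs_info (last_y0, last_res) first, then non_sequences (A, mu)
    new_y0 = A @ (pt.diag(last_y0) @ pt.exp(mu) / 2)
    new_res = last_res + new_y0
    return new_y0, new_res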
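A side note on the step: pt.diag(last_y0) @ pt.exp(mu) multiplies each entry of exp(mu) by the matching entry of last_y0, so it equals an elementwise product and avoids building an n×n matrix. A quick NumPy check, using made-up values for A, mu and last_y0 (assumptions, not taken from your model):

```python
import numpy as np

# Made-up example values (assumptions): 3 states, A is 3x3, mu and last_y0 have length 3.
rng = np.random.default_rng(0)
A = rng.uniform(size=(3, 3))
mu = rng.normal(size=3)
last_y0 = rng.uniform(size=3)

# The expression used inside step(), written in NumPy.
via_diag = A @ (np.diag(last_y0) @ np.exp(mu) / 2)

# Equivalent elementwise form: diag(y) @ v == y * v, with no n x n matrix built.
via_elementwise = A @ (last_y0 * np.exp(mu) / 2)

print(np.allclose(via_diag, via_elementwise))  # True
```

In the PyTensor step the same swap would read `A @ (last_y0 * pt.exp(mu) / 2)`.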

Then you give this all to scan:

# Both recurrences start from y0_vec; A and mu are passed unchanged to every step.
(y0_sequence, res_sequence), updates = pytensor.scan(
    step,
    outputs_info=[y0_vec, y0_vec],
    non_sequences=[A, mu],
    n_steps=n_generations,
)

Warning: I didn’t test any of this code.
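One way to test it is to compare the scan outputs (e.g. after calling .eval() on them, or compiling with pytensor.function) against a plain NumPy loop running the same recurrence on the same inputs. `reference_loop` below is a hypothetical helper for that check, and the inputs are made up. Like scan, it returns one row per step and does not include the initial values.

```python
import numpy as np


def reference_loop(y0_vec, A, mu, n_generations):
    """Plain NumPy version of the scan recurrence, for checking scan's output.

    Both recurrences start from y0_vec, mirroring outputs_info=[y0_vec, y0_vec].
    Returns two arrays of shape (n_generations, len(y0_vec)).
    """
    last_y0 = np.asarray(y0_vec, dtype=float)
    last_res = last_y0.copy()
    y0_rows, res_rows = [], []
    for _ in range(n_generations):
        new_y0 = A @ (np.diag(last_y0) @ np.exp(mu) / 2)
        new_res = last_res + new_y0
        y0_rows.append(new_y0)
        res_rows.append(new_res)
        last_y0, last_res = new_y0, new_res
    return np.stack(y0_rows), np.stack(res_rows)


# Made-up inputs (assumptions): identity A and mu = 0, so each step halves y0.
y0_seq, res_seq = reference_loop([4.0, 8.0], np.eye(2), np.zeros(2), n_generations=3)
# y0_seq rows:  [2, 4], [1, 2], [0.5, 1]
# res_seq rows: [6, 12], [7, 14], [7.5, 15]
```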

Also, all of this is explained better in the PyTensor scan documentation.