Thanks again for your great help.
For the sake of completeness, here is the working code.
Since scan already returns the output of every step, new_res is no longer needed.
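import pytensor
import pytensor.tensor as pt
import pymc as pm

# A, mu, y0_R, dim and N_GEN are defined earlier in the model
def step(last_y0, A, mu):
    # One generation: scale the previous vector by exp(mu) / 2, then apply A
    new_y0 = A @ (last_y0 * pt.exp(mu) / 2)
    return new_y0

# Initial vector: y0_R in the first entry, zeros elsewhere
y0_vec = pm.math.concatenate([[y0_R], pt.zeros(dim - 1)], axis=0)
y0_sequence, updates = pytensor.scan(step,
                                     outputs_info=y0_vec,
                                     non_sequences=[A, mu],
                                     n_steps=N_GEN)
# We only care about the merged vector. scan's output excludes the
# initial state, so y0_vec is added back after summing over generations.
y0 = y0_sequence.sum(axis=0) + y0_vec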
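To sanity-check what the scan computes, here is a minimal plain-NumPy sketch of the same recursion, assuming A is a square matrix and mu broadcasts elementwise against the state vector. reference_cumulative is a hypothetical helper, not part of PyMC or PyTensor; for fixed numeric values of A, mu and y0_R, its output could be compared against the compiled graph.

```python
import numpy as np


def reference_cumulative(A, mu, y0_vec, n_gen):
    """Hypothetical NumPy reference mirroring the scan recursion.

    Applies y_{t+1} = A @ (y_t * exp(mu) / 2) for n_gen steps and returns
    y0_vec plus the sum of all generated vectors (scan's output itself
    does not include the initial state).
    """
    A = np.asarray(A, dtype=float)
    factor = np.exp(np.asarray(mu, dtype=float)) / 2
    y = np.asarray(y0_vec, dtype=float)
    total = y.copy()
    for _ in range(n_gen):
        y = A @ (y * factor)  # one generation, same as step()
        total += y            # accumulate, like y0_sequence.sum(axis=0)
    return total


# Toy check: identity A and mu = 0 halve the vector each generation, so the
# first entry is 1 + 1/2 + 1/4 + 1/8 + 1/16 = 2 - 1/2**4 = 1.9375.
dim, n_gen = 3, 4
y0_vec = np.concatenate([[1.0], np.zeros(dim - 1)])
result = reference_cumulative(np.eye(dim), np.zeros(dim), y0_vec, n_gen)
print(result)  # first entry 1.9375, the others 0
```

The explicit loop trades speed for readability; its only purpose is to make the "sum over generations plus initial vector" bookkeeping easy to verify by hand.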