Learning PyTensor's scan function with a simple population model

Hi community!

I am trying to learn how to use PyTensor's scan function, with the eventual goal of using it in a PyMC model to estimate population-level parameters. I thought I would start with a simple demographic population model, but I am obviously missing some basic concepts. Any help would be much appreciated!

import pytensor
import pytensor.tensor as pt
import numpy as np
import sympy as sym

# define demographic variables
s1, s2, f = sym.symbols('s1 s2 f')

# define symbolic demographic matrix
A_def = sym.Matrix((
    [0, 0, 0.5*f*s2],
    [s1, 0, 0],
    [0, s1, s2]
))

# define vital rates: currently static, but they will eventually be estimated as stochastic variables and used in every pop_step()
vr = {s1: 0.5, s2: 0.8, f: 2.0}  # hypothetical placeholder values

def pop_step(abundance_vector, vr):

    transition_matrix = sym.matrix2numpy(A_def.subs(vr), dtype=float)
    abundance_vector = transition_matrix @ abundance_vector
    return abundance_vector

# population vector; the last value of abundance_vector will be used for each future pop_step()
abundance_vector = pt.vector('abundance_vector')

# outputs_info: initial state of the recurrence; taps=[-1] feeds back the previous step's output
outputs_info = [dict(initial=abundance_vector, taps=[-1])]

# loop over pop_step for n_steps = 50 'years'
results, updates = pytensor.scan(fn=pop_step, sequences=[abundance_vector], outputs_info=outputs_info, n_steps=50)

# set starting population for t=0
abundance_n0 = np.array([100, 145, 210])

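# define pytensor function (named pop_model so it does not overwrite the sympy symbol f)
pop_model = pytensor.function(inputs=[abundance_vector], outputs=results, updates=updates)

# evaluate function
trajectory = pop_model(abundance_n0)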

Running this raises:

ValueError: Length of abundance_vector[t-1] cannot be determined