Extracting individual states from a DifferentialEquation solution in a PyMC model

Hi, I have a model that works as intended when I run it like this (I have stripped out a lot of detail to keep it simple):

import pymc
from pymc.ode import DifferentialEquation

observed_data = output_data.to_numpy()

model_ode_system = DifferentialEquation(
    func=ode_sys,
    times=model_t_eval,
    n_states=n_states,
    n_theta=n_theta,
    t0=model_t_eval[0],
)

with pymc.Model() as model:
    # Priors

    ## Non-fixed parameters -- included in likelihood function
    p1 = pymc.TruncatedNormal("p1", mu=2, sigma=0.2, initval=2, lower=0)
    ......

    ## Fixed parameters
    p2 = 5
    ....

    # theta - parameters of the ODE
    model_theta = [
        p1, p2, ....
    ]

    # ODE system solution
    ode_solution = model_ode_system(y0=model_y0, theta=model_theta)

    # Working:
    sigma = pymc.HalfNormal("sigma", sigma=1)
    pymc.Normal('Y_obs', mu=ode_solution, sigma=sigma, observed=observed_data)


print("Model created successfully.")
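For anyone checking the layout of ode_solution: as far as I understand, DifferentialEquation calls func(y, t, p) and integrates it with scipy.integrate.odeint, so running a right-hand side through odeint directly shows the same layout: one row per entry of times and one column per state, in the order ode_sys returns the derivatives. The ode_sys below is a hypothetical three-state chain, and times, y0 and theta are made-up stand-ins for model_t_eval, model_y0 and model_theta.

```python
import numpy as np
from scipy.integrate import odeint


def ode_sys(y, t, p):
    """Hypothetical 3-state chain x1 -> x2 -> x3 with rates k1, k2.

    Only indexing and arithmetic are used, so the same function should
    also work when PyMC calls it with symbolic PyTensor variables.
    """
    x1, x2 = y[0], y[1]
    k1, k2 = p[0], p[1]
    return [-k1 * x1, k1 * x1 - k2 * x2, k2 * x2]


times = np.linspace(0.0, 5.0, 11)  # stand-in for model_t_eval
y0 = np.array([1.0, 0.0, 0.0])     # stand-in for model_y0
theta = (2.0, 0.5)                 # stand-in for model_theta

# odeint passes the extra arguments after (y, t), so p arrives as `theta`
sol = odeint(ode_sys, y0, times, args=(theta,))

print(sol.shape)  # (11, 3): one row per time point, one column per state
```

The first row is the initial condition, and column j is the state at position j of y (zero-based).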

How would I extract individual states from ode_solution when my observed data are not the states themselves?

Say I have three states, x1, x2 and x3, which combine into the outputs:
y1 = x1 + x2
y2 = x3

(just a random example)

My observed data are measurements of the outputs y1 and y2, not of the states x.
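Since ode_solution is a 2-D tensor of shape (len(times), n_states), each state is one column, and Python indexing starts at 0, so x1 is column 0, not column 1. A small NumPy stand-in (made-up values, not real ODE output) shows the mapping; the same slicing syntax works on the symbolic PyTensor tensor inside the model.

```python
import numpy as np

# Made-up stand-in for ode_solution: 4 time points x 3 states.
# Columns follow the state order in ode_sys: 0 -> x1, 1 -> x2, 2 -> x3.
ode_solution = np.arange(12.0).reshape(4, 3)

mu_y1 = ode_solution[:, 0] + ode_solution[:, 1]  # y1 = x1 + x2
mu_y2 = ode_solution[:, 2]                       # y2 = x3

# Each output is a 1-D series with one value per time point,
# so it lines up with one column of the observed y data.
print(mu_y1)  # values 1, 7, 13, 19
print(mu_y2)  # values 2, 5, 8, 11
```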

My attempt looks something like this:

import pymc
from pymc.ode import DifferentialEquation

observed_data = train_data.to_numpy()

model_ode_system = DifferentialEquation(
    func=ode_sys,
    times=model_t_eval,
    n_states=n_states,
    n_theta=n_theta,
    t0=model_t_eval[0],
)

with pymc.Model() as model:
    # Priors

    ## Non-fixed parameters -- included in likelihood function
    p1 = pymc.TruncatedNormal("p1", mu=2, sigma=0.2, initval=2, lower=0)
    ......

    ## Fixed parameters
    p2 = 5
    ....

    # theta - parameters of the ODE
    model_theta = [
        p1, p2, ....
    ]

    # ODE system solution
    ode_solution = model_ode_system(y0=model_y0, theta=model_theta)

    # Determine y1 and y2 from the states.
    # Indexing is zero-based: column 0 is x1, column 1 is x2, column 2 is x3
    mu_y1 = ode_solution[:, 0] + ode_solution[:, 1]  # y1 = x1 + x2
    mu_y2 = ode_solution[:, 2]                       # y2 = x3

    # OR, by state name (is something like this possible?):
    # mu_y1 = ode_solution['x1'] + ode_solution['x2']
    # mu_y2 = ode_solution['x3']

    sigma_y1 = pymc.HalfNormal("sigma_y1", sigma=1)
    sigma_y2 = pymc.HalfNormal("sigma_y2", sigma=1)

    # Assumes column 0 of observed_data is y1 and column 1 is y2
    pymc.Normal('y1_obs', mu=mu_y1, sigma=sigma_y1, observed=observed_data[:, 0])
    pymc.Normal('y2_obs', mu=mu_y2, sigma=sigma_y2, observed=observed_data[:, 1])

print("Model created successfully.")
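On the "OR" part of the question: as far as I know, the tensor returned by DifferentialEquation has no named fields, so ode_solution['x1'] will not work, but a plain dictionary from state name to column index gives the same readability. For linear outputs like these, another option is an observation matrix H (one row per output), so that all outputs come out of a single matrix product. STATE_IDX and H are hypothetical names, and the NumPy array is a made-up stand-in for the solution.

```python
import numpy as np

# Hypothetical lookup from state name to column of ode_solution
STATE_IDX = {"x1": 0, "x2": 1, "x3": 2}

# Hypothetical observation matrix: row i defines one output as a linear
# combination of (x1, x2, x3); here y1 = x1 + x2 and y2 = x3.
H = np.array([
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

# Made-up stand-in for ode_solution, shape (n_times, n_states)
ode_solution = np.arange(12.0).reshape(4, 3)

# Option 1: select columns by name
mu_y1 = ode_solution[:, STATE_IDX["x1"]] + ode_solution[:, STATE_IDX["x2"]]
mu_y2 = ode_solution[:, STATE_IDX["x3"]]

# Option 2: all outputs at once, (n_times, n_states) @ (n_states, n_outputs)
mu_y = ode_solution @ H.T  # shape (n_times, 2): column 0 is y1, column 1 is y2
```

Inside the model, pytensor.tensor.dot(ode_solution, H.T) should give the same product on the symbolic tensor, and the two likelihoods can then become one, for example sigma = pymc.HalfNormal("sigma", sigma=1, shape=2) followed by pymc.Normal("y_obs", mu=mu_y, sigma=sigma, observed=observed_data), where the length-2 sigma broadcasts across the two output columns. This assumes observed_data contains exactly the columns y1 and y2, in that order.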