How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters

Hey Ricardo thank you for the fast answer. I’m trying to generalise the op for the following 2 dimensional system but I’m not understanding how to do it properly. I attach my code perhaps it is easier to understand this way. Thank you!

import diffrax
import matplotlib.pyplot as plt
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify

import jax
import jax.numpy as jnp

import pymc as pm
import pymc.sampling.jax



vector_field = lambda t, y, args: -args*y




def rhs(t,y, p):
    S, P =  y
    vmax, K_S =  p
    dPdt = vmax * (S / K_S + S)
    dSdt = -dPdt

    dy = dSdt,dPdt
    return dy

term = diffrax.ODETerm(rhs)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=(1,1), saveat=saveat,
                  stepsize_controller=stepsize_controller,args=(2.0,1.0))

print(sol.ts)
print(sol.ys)


def sol_op_jax(y0):
    print(y0[0],y0[1])
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=55,
        dt0=0.1,
        y0=(y0[0],y0[1]),
        args =(y0[2],y0[3]),
        saveat=saveat,
        stepsize_controller=stepsize_controller
    )

    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)
def vjp_sol_op_jax(y0, gz):
    _, vjp_fn = jax.vjp(sol_op_jax, y0)
    
    return vjp_fn(gz)[0],vjp_fn(gz)[1]

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

class SolOp(Op):
    def make_node(self, y0):
        inputs = [pt.as_tensor_variable(y0)]
        # Assume the output to always be a float64 vector
        outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")
        return Apply(self, inputs, outputs)#[0]),Apply(self, inputs, outputs[1])

    def perform(self, node, inputs, outputs):
        (y0,) = inputs
        result = jitted_sol_op_jax(y0)
        
    
        outputs[0][0] = np.asarray(result[0], dtype="float64")
        
        outputs[1][0] = np.asarray(result[1], dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0,) = inputs
        (gz,) = output_gradients
        return [vjp_sol_op(y0, gz)]


class VJPSolOp(Op):
    def make_node(self, y0, gz):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(gz)]
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, gz) = inputs
        result = jitted_vjp_sol_op_jax(y0, gz)
        outputs[0][0] = np.asarray(result[0], dtype="float64")
        print('there',outputs[0][0])
        outputs[1][0] = np.asarray(result[1], dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())

—ERROR------

IndexError: list index out of range
Apply node that caused the error: SolOp(input 0)
Toposort index: 0
Inputs types: [TensorType(float64, shape=(4,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([0.1 , 0.02, 0.1 , 0.1 ])]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_6714/910888793.py", line 106, in <module>
    pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())
  File "/home/pymc_env/lib/python3.11/site-packages/pytensor/gradient.py", line 1796, in verify_grad
    o_output = fun(*tensor_pt)
  File "/home/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/tmp/ipykernel_6714/910888793.py", line 72, in make_node
    outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.