How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters

@POde97 I don’t quite understand what problem you are facing.

Do you have an Op with multiple outputs that verify_grad can’t handle? If you just want to check that the outputs are correct, you can create an Op that returns only one set of outputs at a time.

It’s also fine not to define gradients for every input/output of an Op; for the missing ones you can return one of:

https://pytensor.readthedocs.io/en/latest/library/gradient.html#pytensor.gradient.grad_not_implemented
https://pytensor.readthedocs.io/en/latest/library/gradient.html#pytensor.gradient.grad_undefined