How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters

Hi @nlinden, did you manage to write the code? Could you share it with me?