How to write a PyTensor Op to wrap JAX ODEs with multiple input parameters

I am new to PyMC and JAX. I am trying to use PyMC to estimate the parameters of ODEs implemented in JAX. The ODEs I work with often have many input parameters, so I would like to write a PyTensor Op that can handle this. I have successfully used sunode, but there are some features of JAX and diffrax that I would like to use.

I am trying to modify the Ops written in the tutorial https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ to handle multiple parameters. However, I am running into several errors which I assume are related to problems in my implementation.

Here is the code I have written. The tutorial only estimates the initial condition y0; I added a second parameter, args, to infer as well.

```python
import diffrax
import jax
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

vector_field = lambda t, y, args: -args*y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

# Quick check that the solver works on its own
sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                          stepsize_controller=stepsize_controller, args=2.0)

print(sol.ts)
print(sol.ys)

# JAX function to wrap: returns y at the saveat times
def sol_op_jax(y0, args):
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=3,
        dt0=0.1,
        y0=y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=args,
    )
    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(inputs, output_grads):
    _, vjp_fn = jax.vjp(sol_op_jax, inputs)
    return vjp_fn(output_grads)

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

class SolOp(Op):
    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # Assume the output to always be a float64 vector
        outputs = [pt.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args,) = inputs
        result = jitted_sol_op_jax(y0, args)
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0, args) = inputs
        (gz,) = output_gradients
        return vjp_sol_op(y0, args, gz)

class VJPSolOp(Op):
    def make_node(self, y0, args, output_grads):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(args)]
        inputs.append(pt.as_tensor_variable(output_grads))
        outputs = [inputs[0].type(), inputs[1].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args, output_grads) = inputs
        result = jitted_vjp_sol_op_jax((y0, args), output_grads)
        outputs[0][0] = np.asarray(result, dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

pytensor.gradient.verify_grad(sol_op, (np.array(3.0), np.array(2.0)), rng=np.random.default_rng())
```

I am currently getting the following error when I run this:

```
TypeError: sol_op_jax() missing 1 required positional argument: 'args'
Apply node that caused the error: VJPSolOp(input 0, input 1, random_projection)
Toposort index: 0
Inputs types: [TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (?,))]
Inputs shapes: [(), (), (4,)]
Inputs strides: [(), (), (8,)]
Inputs values: [array(3.), array(2.), array([1.18777319, 0.55691297, 1.30661623, 1.09379087])]
Outputs clients: [['output'], ['output']]
```

Ideally, I would like to be able to pass a list or tuple of parameters to sol_op_jax(). Happy to add the full stack trace or any more info if that helps.

Thanks!
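On passing many parameters at once: one option is to pack them into a single vector input. The JAX function and both Ops then keep a fixed number of inputs however many parameters the ODE has, and `jax.vjp` returns a gradient with the same shape as that vector. A minimal sketch, assuming a hypothetical two-parameter ODE dy/dt = -k*y + c (not from the question) and using its closed-form solution in place of diffrax so that it runs with JAX alone:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # work in float64, like the PyTensor side

ts = jnp.array([0.0, 1.0, 2.0, 3.0])

def sol_op_jax(y0, params):
    # params is ONE vector holding every ODE parameter, here (k, c) for the
    # hypothetical ODE dy/dt = -k * y + c. Its closed-form solution stands in
    # for the diffrax solve so this sketch needs only JAX.
    k, c = params[0], params[1]
    y_inf = c / k
    return y_inf + (y0 - y_inf) * jnp.exp(-k * ts)

@jax.jit
def vjp_sol_op_jax(y0, params, output_grads):
    # y0 and params are separate primals; the pullback returns one
    # cotangent per primal, and the params gradient has params' shape
    _, vjp_fn = jax.vjp(sol_op_jax, y0, params)
    return vjp_fn(output_grads)

y0 = jnp.array(1.0)
params = jnp.array([2.0, 0.5])
g_y0, g_params = vjp_sol_op_jax(y0, params, jnp.ones(4))
print(g_y0, g_params)
```

With diffrax, the same vector can go straight into `args`, since diffrax hands `args` to the vector field unchanged, e.g. `vector_field = lambda t, y, args: -args[0] * y + args[1]`. On the PyTensor side, `outputs = [inputs[0].type(), inputs[1].type()]` in `VJPSolOp.make_node` already gives the `params` gradient the same vector type as the `params` input.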

I think your first problem is in vjp_sol_op_jax. Have you tried calling it directly before wrapping anything in PyTensor? Are you sure the inputs should be passed as a single tuple?
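Calling it directly reproduces the problem. `jax.vjp(f, *primals)` treats each positional argument after `f` as a separate primal, so `jax.vjp(sol_op_jax, inputs)` with `inputs = (y0, args)` calls `sol_op_jax` with one tuple argument, hence `missing 1 required positional argument: 'args'`. A minimal sketch of the fix, assuming the closed-form solution `y0 * exp(-args * t)` of the question's ODE as a stand-in for the diffrax solve so that it runs with JAX alone:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match PyTensor's float64

ts = jnp.array([0.0, 1.0, 2.0, 3.0])

def sol_op_jax(y0, args):
    # Stand-in for the diffrax solve: exact solution of dy/dt = -args * y
    return y0 * jnp.exp(-args * ts)

def vjp_sol_op_jax(y0, args, output_grads):
    # Pass y0 and args as separate primals, not as one tuple
    _, vjp_fn = jax.vjp(sol_op_jax, y0, args)
    # vjp_fn returns a tuple with one cotangent per primal: (d/dy0, d/dargs)
    return vjp_fn(output_grads)

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

# Call it directly, before any PyTensor wrapping
g_y0, g_args = jitted_vjp_sol_op_jax(3.0, 2.0, jnp.ones(4))
print(g_y0, g_args)
```

Keeping the original tuple-based signature and writing `jax.vjp(sol_op_jax, *inputs)` would work equally well; either way the result is a pair of gradients, which matters for `VJPSolOp.perform`.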

Then, double-check the Apply returned by VJPSolOp.make_node: it needs three inputs (y0, args and the gradient of SolOp's single output) and two outputs (the gradients with respect to y0 and args).

The perform method of VJPSolOp also seems problematic. jitted_vjp_sol_op_jax returns two gradients, one for y0 and one for args, which should be saved in outputs[0][0] and outputs[1][0] respectively.

Thank you! I got it working! I appreciate the help!
