How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters

I am new to PyMC and Jax. I am trying to use PyMC to estimate the parameters of an ODE implemented in Jax. The ODEs I work with often have many input parameters, so I would like to write a PyTensor Op that can handle this. I have successfully used sunode, but there are some features from Jax and diffrax that I would like to use.

I am trying to modify the Ops written in the tutorial https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ to handle multiple parameters. However, I am running into several errors, which I assume are caused by problems in my implementation.

Here is the code I have written. In the tutorial, only the initial condition y0 was estimated, but I added an additional parameter args to infer as well.

vector_field = lambda t, y, args: -args*y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                          stepsize_controller=stepsize_controller, args=2.0)

print(sol.ts)
print(sol.ys) 

def sol_op_jax(y0, args):
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=3,
        dt0=0.1,
        y0=y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=args,
    )
    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(inputs, output_grads):
    _, vjp_fn = jax.vjp(sol_op_jax, inputs)
    return vjp_fn(output_grads)

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

class SolOp(Op):
    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # Assume the output to always be a float64 vector
        outputs = [pt.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args,) = inputs
        result = jitted_sol_op_jax(y0, args)
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0, args) = inputs
        (gz,) = output_gradients
        return vjp_sol_op(y0, args, gz)

class VJPSolOp(Op):
    def make_node(self, y0, args, output_grads):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(args)]
        inputs.append(pt.as_tensor_variable(output_grads))
        outputs = [inputs[0].type(), inputs[1].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args, output_grads) = inputs
        result = jitted_vjp_sol_op_jax((y0, args), output_grads)
        outputs[0][0] = np.asarray(result, dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

pytensor.gradient.verify_grad(sol_op, (np.array(3.0), np.array(2.0)), rng=np.random.default_rng())

I am currently getting the following error when I run this:

TypeError: sol_op_jax() missing 1 required positional argument: 'args'
Apply node that caused the error: VJPSolOp(input 0, input 1, random_projection)
Toposort index: 0
Inputs types: [TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (?,))]
Inputs shapes: [(), (), (4,)]
Inputs strides: [(), (), (8,)]
Inputs values: [array(3.), array(2.), array([1.18777319, 0.55691297, 1.30661623, 1.09379087])]
Outputs clients: [['output'], ['output']]

Ideally, I would like to be able to pass a list or tuple of parameters to sol_op_jax(). Happy to add the full stack trace or any more info if that helps.

Thanks!

I think your first problem is in vjp_sol_op_jax. Have you tried calling it directly before wrapping anything in PyTensor? Are you sure the inputs should be in a tuple?

Then, the Apply returned by VJPSolOp.make_node should have three inputs (y0, args, and the output gradient) and two outputs (the gradients with respect to y0 and args); make sure that is what your make_node builds.

The perform method of VJPSolOp also seems problematic. You should have two outputs from jitted_vjp_sol_op_jax, which should be saved in outputs[0][0] and outputs[1][0] respectively.
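For concreteness, a minimal sketch of those fixes (an illustration assuming the definitions above, not necessarily the exact final code):

def vjp_sol_op_jax(y0, args, output_grads):
    # unpack the inputs so sol_op_jax receives y0 and args as two
    # separate positional arguments instead of a single tuple
    _, vjp_fn = jax.vjp(sol_op_jax, y0, args)
    return vjp_fn(output_grads)  # a tuple: (grad_y0, grad_args)

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

class VJPSolOp(Op):
    def make_node(self, y0, args, output_grads):
        # three inputs: y0, args, and the gradient of SolOp's single output
        inputs = [
            pt.as_tensor_variable(y0),
            pt.as_tensor_variable(args),
            pt.as_tensor_variable(output_grads),
        ]
        # two outputs: the gradients with respect to y0 and args
        outputs = [inputs[0].type(), inputs[1].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        y0, args, output_grads = inputs
        grad_y0, grad_args = jitted_vjp_sol_op_jax(y0, args, output_grads)
        outputs[0][0] = np.asarray(grad_y0, dtype="float64")
        outputs[1][0] = np.asarray(grad_args, dtype="float64")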

Thank you! I got it working! I appreciate the help!


Hi Ricardo and all the PyMC community :smiley:

I'm still wondering how it's possible to store something in outputs[1][0], because verify_grad() does not support multiple outputs.

As JB's comment in the verify_grad source says: 'we could make loop over outputs making random projections R for each, but this doesn't handle the case where not all the outputs are differentiable… so I leave this as TODO for now -JB.'
So the for loop is not optimal. Do you have any ideas on how to do that? I would like to contribute (for a system of coupled equations it seems to be mandatory to store multiple outputs; am I correct?), but I cannot see any better approach.

For the one-dimensional example with one parameter (above), this is the solution, by the way:

import diffrax
import matplotlib.pyplot as plt
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import arviz as az
import jax
import jax.numpy as jnp

import pymc as pm
import pymc.sampling.jax


vector_field = lambda t, y, args: -args*y

term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                          stepsize_controller=stepsize_controller, args=2.0)


noise = np.random.normal(0, 0.1, sol.ys.shape)
noisy_array = sol.ys + noise


def sol_op_jax(y0):
    
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=55,
        dt0=0.1,
        y0=y0[0],
        args=y0[1],
        saveat=saveat,
        stepsize_controller=stepsize_controller
    )

    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(y0, gz):
    _, vjp_fn = jax.vjp(sol_op_jax, y0)
    
    return vjp_fn(gz)[0]

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)


class SolOp(Op):
    def make_node(self, y0):
        inputs = [pt.as_tensor_variable(y0)]
        # Assume the output to always be a float64 vector
        outputs = [pt.vector(dtype="float64")]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0,) = inputs
        result = jitted_sol_op_jax(y0)
        outputs[0][0] = np.asarray(result, dtype="float64")
        
    def grad(self, inputs, output_gradients):
        (y0,) = inputs
        (gz,) = output_gradients
        return [vjp_sol_op(y0, gz)]


class VJPSolOp(Op):
    def make_node(self, y0, gz):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(gz)]
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, gz) = inputs
        result = jitted_vjp_sol_op_jax(y0, gz)
        outputs[0][0] = np.asarray(result, dtype="float64")
        
sol_op = SolOp()
vjp_sol_op = VJPSolOp()

pytensor.gradient.verify_grad(sol_op, (np.array([0.1, 0.02]),), rng=np.random.default_rng())

@jax_funcify.register(SolOp)
def sol_op_jax_funcify(op, **kwargs):
    return sol_op_jax

@jax_funcify.register(VJPSolOp)
def vjp_sol_op_jax_funcify(op, **kwargs):
    return vjp_sol_op_jax


with pm.Model() as model:
    y0 = pm.Normal("y0")
    y1 = pm.Normal("y1")
    
    ys = sol_op(pm.math.stack([y0,y1]))
    noise = pm.HalfNormal("noise")
    llike = pm.Normal("llike", ys, noise, observed=noisy_array)


pm.model_to_graphviz(model)

ip = model.initial_point()
logp_fn = model.compile_fn(model.logp(sum=False))
logp_fn(ip)

logp_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
logp_fn(ip)

dlogp_fn = model.compile_fn(model.dlogp())
dlogp_fn(ip)

dlogp_fn = model.compile_fn(model.dlogp(), mode="JAX")
dlogp_fn(ip)

sampler = "SMC with Likelihood"
print('start '+sampler)
with model:                              
    trace_pymc_ode = pm.sample_smc(draws=4000,cores=4)
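By the way, before wrapping anything in PyTensor it helps to call the jitted functions directly, as suggested earlier in the thread. A minimal check, assuming the definitions above:

y0_test = jnp.array([0.1, 0.02])                # packed [initial condition, rate]
ys_test = jitted_sol_op_jax(y0_test)            # forward solve at the saveat times
gz_test = jnp.ones_like(ys_test)                # a dummy output gradient
print(ys_test)
print(jitted_vjp_sol_op_jax(y0_test, gz_test))  # vector-Jacobian product, shape (2,)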

@POde97 I don't quite get what problem you are facing.

You have an Op with multiple outputs and verify_grad can't handle those? If you just want to check that the outputs are correct, you can create an Op (or a small wrapper) that returns only one output at a time, as in the sketch below.
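For instance, a minimal sketch of that idea (illustrative, not from the thread), assuming a multi-output sol_op with a packed parameter vector as input; verify_grad accepts any callable that maps input tensors to a single output variable:

def first_output(y0):
    # calling a multi-output Op returns a list of output variables
    return sol_op(y0)[0]

def second_output(y0):
    return sol_op(y0)[1]

x = np.array([0.1, 0.02, 0.1, 0.1])
rng = np.random.default_rng()
pytensor.gradient.verify_grad(first_output, (x,), rng=rng)
pytensor.gradient.verify_grad(second_output, (x,), rng=rng)

Note that the Op's grad method still has to handle the gradients of both outputs for this to work.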

It's also fine not to have gradients defined for all the inputs/outputs of an Op; you can use one of:

https://pytensor.readthedocs.io/en/latest/library/gradient.html#pytensor.gradient.grad_not_implemented
https://pytensor.readthedocs.io/en/latest/library/gradient.html#pytensor.gradient.grad_undefined
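For example, a sketch of grad_not_implemented with the two-input SolOp from the first post, if you only need the gradient with respect to y0 (the class name here is hypothetical):

from pytensor.gradient import grad_not_implemented

class SolOpY0GradOnly(SolOp):
    def grad(self, inputs, output_gradients):
        y0, args = inputs
        (gz,) = output_gradients
        grad_y0, _ = vjp_sol_op(y0, args, gz)
        # mark the gradient with respect to args as not implemented
        return [grad_y0, grad_not_implemented(self, 1, args)]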

Hey Ricardo, thank you for the fast answer. I'm trying to generalise the Op for the following two-dimensional system, but I don't understand how to do it properly. I attach my code; perhaps it is easier to understand this way. Thank you!

import diffrax
import matplotlib.pyplot as plt
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify

import jax
import jax.numpy as jnp

import pymc as pm
import pymc.sampling.jax



def rhs(t, y, p):
    S, P = y
    vmax, K_S = p
    # Michaelis-Menten rate
    dPdt = vmax * S / (K_S + S)
    dSdt = -dPdt
    return dSdt, dPdt

term = diffrax.ODETerm(rhs)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=(1,1), saveat=saveat,
                  stepsize_controller=stepsize_controller,args=(2.0,1.0))

print(sol.ts)
print(sol.ys)


def sol_op_jax(y0):
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=55,
        dt0=0.1,
        y0=(y0[0],y0[1]),
        args=(y0[2], y0[3]),
        saveat=saveat,
        stepsize_controller=stepsize_controller
    )

    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(y0, gz):
    _, vjp_fn = jax.vjp(sol_op_jax, y0)
    return vjp_fn(gz)[0], vjp_fn(gz)[1]

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

class SolOp(Op):
    def make_node(self, y0):
        inputs = [pt.as_tensor_variable(y0)]
        # Assume the output to always be a float64 vector
        outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")
        return Apply(self, inputs, outputs)#[0]),Apply(self, inputs, outputs[1])

    def perform(self, node, inputs, outputs):
        (y0,) = inputs
        result = jitted_sol_op_jax(y0)
        
    
        outputs[0][0] = np.asarray(result[0], dtype="float64")
        
        outputs[1][0] = np.asarray(result[1], dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0,) = inputs
        (gz,) = output_gradients
        return [vjp_sol_op(y0, gz)]


class VJPSolOp(Op):
    def make_node(self, y0, gz):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(gz)]
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, gz) = inputs
        result = jitted_vjp_sol_op_jax(y0, gz)
        outputs[0][0] = np.asarray(result[0], dtype="float64")
        print('there',outputs[0][0])
        outputs[1][0] = np.asarray(result[1], dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())

----- ERROR -----

IndexError: list index out of range
Apply node that caused the error: SolOp(input 0)
Toposort index: 0
Inputs types: [TensorType(float64, shape=(4,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([0.1 , 0.02, 0.1 , 0.1 ])]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_6714/910888793.py", line 106, in <module>
    pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())
  File "/home/pymc_env/lib/python3.11/site-packages/pytensor/gradient.py", line 1796, in verify_grad
    o_output = fun(*tensor_pt)
  File "/home/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/tmp/ipykernel_6714/910888793.py", line 72, in make_node
    outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Hi @nlinden, did you manage to get the code working? Can you share it with me?

Hi @Elia97, yes, I did get it to work for ODEs with multiple input parameters and outputs. I implemented the whole thing in a custom class to make it easier to reuse in different projects.

My custom PyTensor Op is:

class SolOp(Op):
    def __init__(self, sol_op_jax_jitted, vjp_sol_op):
        self.sol_op_jax_jitted = sol_op_jax_jitted
        self.vjp_sol_op = vjp_sol_op

    def make_node(self, *inputs):
        # Convert our inputs to symbolic variables
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # Assume the output to always be a float64 matrix
        outputs = [pt.matrix()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = self.sol_op_jax_jitted(*inputs)
        outputs[0][0] = np.asarray(result, dtype="float64")
        
    def grad(self, inputs, output_grads):
        (gz,) = output_grads
        return self.vjp_sol_op(inputs, gz)
    
class VJPSolOp(Op):
    def __init__(self, vjp_sol_op_jax_jitted):
        self.vjp_sol_op_jax_jitted = vjp_sol_op_jax_jitted

    def make_node(self, inputs, gz):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        inputs.append(pt.as_tensor_variable(gz))
        # one gradient output per parameter input (gz excluded)
        outputs = [inp.type() for inp in inputs[:-1]]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        *params, gz = inputs
        result = self.vjp_sol_op_jax_jitted(gz, *params)
        for i, res in enumerate(result):
            outputs[i][0] = np.asarray(res, dtype="float64")

This uses Python for loops in perform, so it's not fully Jax-compatible. I would appreciate any suggestions and improvements!
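For reference, a minimal sketch of one way to wire these classes up (assuming the diffrax term, solver, saveat, and stepsize_controller from earlier in the thread, and a vector-valued y0 so that sol.ys is a matrix):

def sol_op_jax(y0, args):
    sol = diffrax.diffeqsolve(
        term, solver, t0=0, t1=3, dt0=0.1,
        y0=y0, args=args,
        saveat=saveat, stepsize_controller=stepsize_controller,
    )
    return sol.ys

def vjp_sol_op_jax(gz, *params):
    # one vector-Jacobian product per parameter, in input order
    _, vjp_fn = jax.vjp(sol_op_jax, *params)
    return vjp_fn(gz)

vjp_sol_op = VJPSolOp(jax.jit(vjp_sol_op_jax))
sol_op = SolOp(jax.jit(sol_op_jax), vjp_sol_op)

# 'initial_states' and 'rate_params' are hypothetical PyTensor variables
ys = sol_op(initial_states, rate_params)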

Hope this helps.