How to write a PyTensor Op to wrap Jax ODEs with multiple input parameters

Hi Ricardo and all pymc community :smiley:

I’m still wondering how it is possible to store something in outputs[1][0], because verify_grad() does not support multiple outputs.

As JB suggested ‘we could make loop over outputs making random projections R for each,
but this doesn’t handle the case where not all the outputs are differentiable… so I leave this as TODO for now -JB.’
So the for loop is not optimal. Do you have any idea how to do that? I would like to contribute (because for a system of coupled equations it seems to be mandatory to store multiple outputs, am I correct?), but I cannot see a better approach.

By the way, for the one-dimensional example with one parameter (above), this is the solution:

import diffrax
import matplotlib.pyplot as plt
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import arviz as az
import jax
import jax.numpy as jnp

import pymc as pm
import pymc.sampling.jax


def vector_field(t, y, args):
    """Right-hand side of the linear ODE ``dy/dt = -args * y``.

    ``t`` is accepted because diffrax calls the vector field as
    ``f(t, y, args)``, but it is not used: the system is autonomous.
    ``args`` acts as the rate multiplying ``y``.
    """
    return -args * y


# Solver configuration; these module-level objects are reused by sol_op_jax.
term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

# Reference solution with y(0) = 1 and rate 2.0, saved at the `saveat` times.
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=1,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    args=2.0,
)

# Synthetic observations: the reference solution plus Gaussian noise (sd 0.1).
noise = np.random.normal(0, 0.1, sol.ys.shape)
noisy_array = sol.ys + noise


def sol_op_jax(y0):
    """Solve the ODE with diffrax and return the states at the ``saveat`` times.

    Parameters
    ----------
    y0 : array-like with at least two entries
        Packed parameters. ``y0[0]`` is used as the initial state and
        ``y0[1]`` is forwarded to the vector field as ``args``.

    Returns
    -------
    The ``ys`` attribute of the diffrax solution, i.e. the state at each
    time requested by the module-level ``saveat``.
    """
    initial_state = y0[0]
    rate = y0[1]

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        # The last time requested by `saveat` is 3.0, and the reference data
        # were generated with t1=3. Integrating further (previously t1=55)
        # only adds forward/backward solver work and can overflow when the
        # sampled rate makes the solution grow.
        t1=3,
        dt0=0.1,
        y0=initial_state,
        args=rate,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )

    return sol.ys


# Compiled version used by SolOp.perform on the non-JAX backends.
jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(y0, gz):
    """Vector-Jacobian product of ``sol_op_jax`` at ``y0`` with cotangent ``gz``.

    ``jax.vjp`` returns the primal output together with a pullback function;
    only the pullback is needed here. The pullback yields one cotangent per
    primal input, and ``sol_op_jax`` has a single input.
    """
    pullback = jax.vjp(sol_op_jax, y0)[1]
    (y0_cotangent,) = pullback(gz)
    return y0_cotangent


# Compiled version used by VJPSolOp.perform.
jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)


class SolOp(Op):
    """PyTensor Op that evaluates the jitted diffrax solve.

    Takes one input (the packed parameter vector) and produces one output,
    declared as a float64 vector.
    """

    def make_node(self, y0):
        """Build the Apply node: one tensor input, one float64 vector output."""
        packed_params = pt.as_tensor_variable(y0)
        # The output is assumed to always be a float64 vector.
        solution = pt.vector(dtype="float64")
        return Apply(self, [packed_params], [solution])

    def perform(self, node, inputs, outputs):
        """Run the jitted solver and store the result in the output storage."""
        (packed_params,) = inputs
        solved = jitted_sol_op_jax(packed_params)
        outputs[0][0] = np.asarray(solved, dtype="float64")

    def grad(self, inputs, output_gradients):
        """Return the symbolic gradient, delegated to the ``vjp_sol_op`` Op."""
        (packed_params,) = inputs
        (upstream_grad,) = output_gradients
        return [vjp_sol_op(packed_params, upstream_grad)]


class VJPSolOp(Op):
    """PyTensor Op that evaluates the jitted vector-Jacobian product.

    Inputs are the packed parameter vector and the upstream gradient; the
    single output has the same tensor type as the parameter input.
    """

    def make_node(self, y0, gz):
        """Build the Apply node with an output typed like the first input."""
        packed_params = pt.as_tensor_variable(y0)
        upstream_grad = pt.as_tensor_variable(gz)
        param_grad = packed_params.type()
        return Apply(self, [packed_params, upstream_grad], [param_grad])

    def perform(self, node, inputs, outputs):
        """Run the jitted VJP and store the result as a float64 array."""
        packed_params, upstream_grad = inputs
        vjp_value = jitted_vjp_sol_op_jax(packed_params, upstream_grad)
        outputs[0][0] = np.asarray(vjp_value, dtype="float64")
        
# Op instances used to build graphs; SolOp.grad looks up `vjp_sol_op` by name.
sol_op = SolOp()
vjp_sol_op = VJPSolOp()

# Check the Op's gradient at the test point [0.1, 0.02].
pytensor.gradient.verify_grad(
    fun=sol_op,
    pt=(np.array([0.1, 0.02]),),
    rng=np.random.default_rng(),
)

@jax_funcify.register(SolOp)
def sol_op_jax_funcify(op, **kwargs):
    """Return the JAX function registered for ``SolOp``.

    The un-jitted ``sol_op_jax`` is returned; ``op`` and ``kwargs`` are not used.
    """
    jax_implementation = sol_op_jax
    return jax_implementation

@jax_funcify.register(VJPSolOp)
def vjp_sol_op_jax_funcify(op, **kwargs):
    """Return the JAX function registered for ``VJPSolOp``.

    The un-jitted ``vjp_sol_op_jax`` is returned; ``op`` and ``kwargs`` are not used.
    """
    jax_implementation = vjp_sol_op_jax
    return jax_implementation


with pm.Model() as model:
    # Priors on the two entries of the packed parameter vector.
    y0 = pm.Normal("y0")
    y1 = pm.Normal('y1')

    # Stack the parameters into one vector and push it through the ODE Op.
    theta = pm.math.stack([y0, y1])
    ys = sol_op(theta)

    # Gaussian likelihood around the ODE solution with a learned noise scale.
    noise = pm.HalfNormal("noise")
    llike = pm.Normal("llike", mu=ys, sigma=noise, observed=noisy_array)


# Render the model graph (the return value is only displayed when this is the
# last expression of a notebook cell).
pm.model_to_graphviz(model)

# Smoke checks: evaluate logp and its gradient at the model's initial point,
# first with no explicit compilation mode and then with mode="JAX".
ip = model.initial_point()
# logp with sum=False, compiled without an explicit mode.
logp_fn = model.compile_fn(model.logp(sum=False))
logp_fn(ip)

# Same logp graph compiled with mode="JAX"; presumably this exercises the
# jax_funcify registrations for the custom Ops — TODO confirm.
logp_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
logp_fn(ip)

# Gradient of logp, compiled without an explicit mode.
dlogp_fn = model.compile_fn(model.dlogp())
dlogp_fn(ip)

# Gradient of logp compiled with mode="JAX".
dlogp_fn = model.compile_fn(model.dlogp(), mode="JAX")
dlogp_fn(ip)

# Draw posterior samples with Sequential Monte Carlo.
sampler = "SMC with Likelihood"
print(f"start {sampler}")
with model:
    trace_pymc_ode = pm.sample_smc(draws=4000, cores=4)