Hey Ricardo thank you for the fast answer. I’m trying to generalise the op for the following 2 dimensional system but I’m not understanding how to do it properly. I attach my code perhaps it is easier to understand this way. Thank you!
import diffrax
import matplotlib.pyplot as plt
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import jax
import jax.numpy as jnp
import pymc as pm
import pymc.sampling.jax
vector_field = lambda t, y, args: -args*y
def rhs(t,y, p):
S, P = y
vmax, K_S = p
dPdt = vmax * (S / K_S + S)
dSdt = -dPdt
dy = dSdt,dPdt
return dy
term = diffrax.ODETerm(rhs)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=(1,1), saveat=saveat,
stepsize_controller=stepsize_controller,args=(2.0,1.0))
print(sol.ts)
print(sol.ys)
def sol_op_jax(y0):
print(y0[0],y0[1])
sol = diffrax.diffeqsolve(
term,
solver,
t0=0,
t1=55,
dt0=0.1,
y0=(y0[0],y0[1]),
args =(y0[2],y0[3]),
saveat=saveat,
stepsize_controller=stepsize_controller
)
return sol.ys
jitted_sol_op_jax = jax.jit(sol_op_jax)
def vjp_sol_op_jax(y0, gz):
_, vjp_fn = jax.vjp(sol_op_jax, y0)
return vjp_fn(gz)[0],vjp_fn(gz)[1]
jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)
class SolOp(Op):
def make_node(self, y0):
inputs = [pt.as_tensor_variable(y0)]
# Assume the output to always be a float64 vector
outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")
return Apply(self, inputs, outputs)#[0]),Apply(self, inputs, outputs[1])
def perform(self, node, inputs, outputs):
(y0,) = inputs
result = jitted_sol_op_jax(y0)
outputs[0][0] = np.asarray(result[0], dtype="float64")
outputs[1][0] = np.asarray(result[1], dtype="float64")
def grad(self, inputs, output_gradients):
(y0,) = inputs
(gz,) = output_gradients
return [vjp_sol_op(y0, gz)]
class VJPSolOp(Op):
def make_node(self, y0, gz):
inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(gz)]
outputs = [inputs[0].type()]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs):
(y0, gz) = inputs
result = jitted_vjp_sol_op_jax(y0, gz)
outputs[0][0] = np.asarray(result[0], dtype="float64")
print('there',outputs[0][0])
outputs[1][0] = np.asarray(result[1], dtype="float64")
sol_op = SolOp()
vjp_sol_op = VJPSolOp()
pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())
—ERROR------
IndexError: list index out of range
Apply node that caused the error: SolOp(input 0)
Toposort index: 0
Inputs types: [TensorType(float64, shape=(4,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([0.1 , 0.02, 0.1 , 0.1 ])]
Outputs clients: [['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/home/pymc_env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_6714/910888793.py", line 106, in <module>
pytensor.gradient.verify_grad(sol_op,(np.array([0.1, 0.02,0.1,0.1]),), rng=np.random.default_rng())
File "/home/pymc_env/lib/python3.11/site-packages/pytensor/gradient.py", line 1796, in verify_grad
o_output = fun(*tensor_pt)
File "/home/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
node = self.make_node(*inputs, **kwargs)
File "/tmp/ipykernel_6714/910888793.py", line 72, in make_node
outputs = [pt.vector(dtype="float64")]#,[pt.vector(dtype="float64")]]#,pt.vector(dtype="float64")])]#,pt.vector(dtype="float64")
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.