I am new to PyMC and JAX. I am trying to use PyMC to estimate the parameters of ODEs implemented in JAX. The ODEs I work with often have many input parameters, so I would like to write a PyTensor Op that can handle this. I have successfully used sunode, but there are some features of JAX and diffrax that I would like to use.
I am trying to modify the Ops written in the tutorial https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ to handle multiple parameters. However, I am running into several errors which I assume are related to problems in my implementation.
Here is the code I have written. In the tutorial, only the initial condition `y0` was estimated, but I added an additional parameter `args` to infer as well.
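```python
import diffrax
import jax
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

# Exponential decay: dy/dt = -args * y
vector_field = lambda t, y, args: -args*y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                          stepsize_controller=stepsize_controller, args=2.0)
print(sol.ts)
print(sol.ys)
```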
```python
def sol_op_jax(y0, args):
    # Solve the ODE for a given initial condition and parameter,
    # returning the solution at the SaveAt times
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=3,
        dt0=0.1,
        y0=y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=args,
    )
    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(inputs, output_grads):
    # Vector-Jacobian product of sol_op_jax, evaluated at `inputs`
    _, vjp_fn = jax.vjp(sol_op_jax, inputs)
    return vjp_fn(output_grads)

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)
```
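`jax.vjp(fun, *primals)` treats each extra positional argument as one input of `fun`, so `jax.vjp(sol_op_jax, inputs)` calls `sol_op_jax` with the whole `(y0, args)` tuple as its only argument, which matches the `TypeError` reported further down. A minimal sketch of unpacking the tuple instead. To keep it self-contained, the diffrax solve is replaced by the closed-form solution `y0 * exp(-args * t)` of the same ODE, and `closed_form_sol` / `vjp_closed_form` are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # return float64, like the PyTensor Ops expect

ts = jnp.array([0.0, 1.0, 2.0, 3.0])

def closed_form_sol(y0, args):
    # Stand-in for the diffrax solve: exact solution of dy/dt = -args * y
    return y0 * jnp.exp(-args * ts)

def vjp_closed_form(inputs, output_grads):
    # Unpack (y0, args) so that each becomes a separate primal of jax.vjp
    _, vjp_fn = jax.vjp(closed_form_sol, *inputs)
    # vjp_fn returns one cotangent per primal: (grad wrt y0, grad wrt args)
    return vjp_fn(output_grads)

jitted_vjp = jax.jit(vjp_closed_form)

g = jnp.array([1.0, 0.5, 1.3, 1.1])
grad_y0, grad_args = jitted_vjp((3.0, 2.0), g)
print(grad_y0, grad_args)
```

Because `vjp_fn` now returns a tuple of two cotangents, whatever consumes the result has to handle two values rather than one.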
```python
class SolOp(Op):
    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # Assume the output to always be a float64 vector
        outputs = [pt.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args,) = inputs
        result = jitted_sol_op_jax(y0, args)
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0, args) = inputs
        (gz,) = output_gradients
        return vjp_sol_op(y0, args, gz)

class VJPSolOp(Op):
    def make_node(self, y0, args, output_grads):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(args)]
        inputs.append(pt.as_tensor_variable(output_grads))
        # One gradient output per parameter, with the same type as that parameter
        outputs = [inputs[0].type(), inputs[1].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, args, output_grads) = inputs
        result = jitted_vjp_sol_op_jax((y0, args), output_grads)
        outputs[0][0] = np.asarray(result, dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()
pytensor.gradient.verify_grad(sol_op, (np.array(3.0), np.array(2.0)), rng=np.random.default_rng())
```
I am currently getting the following error when I run this:

```
TypeError: sol_op_jax() missing 1 required positional argument: 'args'
Apply node that caused the error: VJPSolOp(input 0, input 1, random_projection)
Toposort index: 0
Inputs types: [TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (?,))]
Inputs shapes: [(), (), (4,)]
Inputs strides: [(), (), (8,)]
Inputs values: [array(3.), array(2.), array([1.18777319, 0.55691297, 1.30661623, 1.09379087])]
Outputs clients: [['output'], ['output']]
```
Ideally, I would like to be able to pass a list or tuple of parameters to `sol_op_jax()`. Happy to add the full stack trace or any more info if that helps.
Thanks!