Hi Ricardo and all the PyMC community!
I’m still wondering how it’s possible to store something in outputs[1][0], because verify_grad() does not support multiple outputs.
As JB’s TODO comment in the source suggests: ‘we could make loop over outputs making random projections R for each,
but this doesn’t handle the case where not all the outputs are differentiable… so I leave this as TODO for now -JB.’
So the for loop is not optimal. Do you have any idea how to do that? I would like to contribute (because for a system of coupled equations it seems to be mandatory to store multiple outputs, am I correct?), but I cannot see any better idea.
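For reference, here is a minimal sketch of what I have in mind for the multi-output case. MultiSolOp and sol_op_jax2 are names I made up; I’m assuming sol_op_jax2(y0) returns a tuple of two float64 vectors, and the imports match the full example below:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

# Hypothetical sketch: a two-output Op; sol_op_jax2 is a made-up JAX function
# assumed to return a tuple of two float64 vectors.
class MultiSolOp(Op):
    def make_node(self, y0):
        inputs = [pt.as_tensor_variable(y0)]
        # one output slot per solution component
        outputs = [pt.vector(dtype="float64"), pt.vector(dtype="float64")]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0,) = inputs
        res0, res1 = sol_op_jax2(y0)
        outputs[0][0] = np.asarray(res0, dtype="float64")
        outputs[1][0] = np.asarray(res1, dtype="float64")  # the second output lives here

grad would then receive one output gradient per output, and some of them can be disconnected when only one output enters the cost, which is exactly the complication JB’s TODO is about.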
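And for the gradient check itself, the only workaround I can see is exactly the loop JB describes: wrap the Op so each call exposes a single output, and run verify_grad once per output. A sketch (it assumes MultiSolOp also implements grad, and that grad copes with disconnected output gradients):

# Check each output one at a time, since verify_grad only handles
# a single differentiable output.
multi_sol_op = MultiSolOp()
for i in range(2):
    pytensor.gradient.verify_grad(
        lambda y0, i=i: multi_sol_op(y0)[i],  # expose output i only
        (np.array([0.1, 0.02]),),
        rng=np.random.default_rng(),
    )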
By the way, for the one-dimensional example with one parameter (above), this is the solution:
import diffrax
import matplotlib.pyplot as plt
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
import arviz as az
import jax
import jax.numpy as jnp
import pymc as pm
import pymc.sampling.jax
# Exponential decay dy/dt = -a * y; the decay rate a is passed through args
vector_field = lambda t, y, args: -args * y
term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=[0.0, 1.0, 2.0, 3.0])
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

# Generate noisy synthetic observations from the true solution (y0=1, a=2)
sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                          stepsize_controller=stepsize_controller, args=2.0)
noise = np.random.normal(0, 0.1, sol.ys.shape)
noisy_array = sol.ys + noise
def sol_op_jax(y0):
    # y0 packs both parameters: y0[0] is the initial condition, y0[1] the decay rate
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=3,  # integrate up to the last point in saveat
        dt0=0.1,
        y0=y0[0],
        args=y0[1],
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol.ys

jitted_sol_op_jax = jax.jit(sol_op_jax)

def vjp_sol_op_jax(y0, gz):
    # vector-Jacobian product of the solver output with respect to y0
    _, vjp_fn = jax.vjp(sol_op_jax, y0)
    return vjp_fn(gz)[0]

jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)
class SolOp(Op):
    def make_node(self, y0):
        inputs = [pt.as_tensor_variable(y0)]
        # Assume the output is always a float64 vector
        outputs = [pt.vector(dtype="float64")]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0,) = inputs
        result = jitted_sol_op_jax(y0)
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        (y0,) = inputs
        (gz,) = output_gradients
        return [vjp_sol_op(y0, gz)]
class VJPSolOp(Op):
    def make_node(self, y0, gz):
        inputs = [pt.as_tensor_variable(y0), pt.as_tensor_variable(gz)]
        outputs = [inputs[0].type()]  # the VJP has the same type/shape as y0
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (y0, gz) = inputs
        result = jitted_vjp_sol_op_jax(y0, gz)
        outputs[0][0] = np.asarray(result, dtype="float64")

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

# Check the custom gradient against a finite-difference approximation
pytensor.gradient.verify_grad(sol_op, (np.array([0.1, 0.02]),), rng=np.random.default_rng())
# Tell PyTensor's JAX backend how to compile these Ops
@jax_funcify.register(SolOp)
def sol_op_jax_funcify(op, **kwargs):
    return sol_op_jax

@jax_funcify.register(VJPSolOp)
def vjp_sol_op_jax_funcify(op, **kwargs):
    return vjp_sol_op_jax
with pm.Model() as model:
    y0 = pm.Normal("y0")
    y1 = pm.Normal("y1")
    ys = sol_op(pm.math.stack([y0, y1]))
    noise = pm.HalfNormal("noise")
    llike = pm.Normal("llike", ys, noise, observed=noisy_array)

pm.model_to_graphviz(model)

# Sanity checks: evaluate logp and dlogp in both the default and JAX modes
ip = model.initial_point()
logp_fn = model.compile_fn(model.logp(sum=False))
logp_fn(ip)
logp_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
logp_fn(ip)
dlogp_fn = model.compile_fn(model.dlogp())
dlogp_fn(ip)
dlogp_fn = model.compile_fn(model.dlogp(), mode="JAX")
dlogp_fn(ip)
sampler = "SMC with Likelihood"
print("start " + sampler)
with model:
    trace_pymc_ode = pm.sample_smc(draws=4000, cores=4)
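Since both Ops are registered with the JAX backend, I believe the JAX-based NUTS sampler should also work on this model. A sketch, not tested, assuming your PyMC version still exposes pm.sampling.jax.sample_numpyro_nuts:

# Relies on the jax_funcify registrations above so the whole
# logp/dlogp graph compiles to JAX.
with model:
    trace_jax = pm.sampling.jax.sample_numpyro_nuts(draws=1000, chains=4)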