So my goal in creating the previous PyTensor Op was to feed the analytical gradient computed by jitted_grad_ecc into PyMC’s wider graph as a sort of “differentiation rule” for this specific jitted_ecc function only, while sampling with numpyro.
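In plain JAX, that kind of per-function “differentiation rule” can be written with jax.custom_vjp. Here is a minimal sketch, assuming a hypothetical toy function of params = [a, b] (an ellipse eccentricity) as a stand-in for jitted_ecc, with its hand-derived gradient standing in for jitted_grad_ecc. The real function bodies aren’t shown here, and ecc_with_rule is just an illustrative name:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for jitted_ecc: eccentricity of an ellipse with
# semi-axes a > b, e = sqrt(1 - (b / a)**2). The real function is not shown here.
def ecc(params):
    a, b = params[0], params[1]
    return jnp.sqrt(1.0 - (b / a) ** 2)


# Hypothetical stand-in for jitted_grad_ecc: the hand-derived gradient
# [de/da, de/db] of the toy function above.
def grad_ecc(params):
    a, b = params[0], params[1]
    e = ecc(params)
    return jnp.stack([b**2 / (a**3 * e), -b / (a**2 * e)])


jitted_ecc = jax.jit(ecc)
jitted_grad_ecc = jax.jit(grad_ecc)


@jax.custom_vjp
def ecc_with_rule(params):
    # Forward value comes straight from the jitted function.
    return jitted_ecc(params)


def _ecc_fwd(params):
    # Return the value plus the residuals the backward pass needs.
    return jitted_ecc(params), params


def _ecc_bwd(params, g):
    # g is the cotangent of the scalar output; chain rule gives g * de/dparams.
    return (g * jitted_grad_ecc(params),)


ecc_with_rule.defvjp(_ecc_fwd, _ecc_bwd)
```

With this, jax.grad(ecc_with_rule) goes through _ecc_bwd, i.e. through the analytical gradient, instead of JAX differentiating the forward computation itself.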
Since I noticed that this wasn’t happening, would adding something like
logp_fn = model.compile_fn(model.logp(sum=False), mode='JAX')
after the model is defined work?
I have checked out How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs and the link you sent before, and I think that approach should work. Would I be missing anything?