PyMC Graph accepting gradients?

So my goal in creating the previous PyTensor Op was to feed the analytical gradient computed by jitted_grad_ecc into PyMC’s wider graph as a sort of “differentiation rule” for this specific jitted_ecc function only, while sampling with numpyro.
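Within JAX itself, a hand-written gradient can be attached to a function as a differentiation rule with `jax.custom_vjp`. Below is a minimal sketch. It assumes a function that maps an array to a scalar. `ecc` and `grad_ecc` are hypothetical stand-ins for `jitted_ecc` and `jitted_grad_ecc`, whose real definitions are not shown here:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for jitted_ecc (the real function is not shown in the post).
@jax.custom_vjp
def ecc(x):
    return jnp.sqrt(jnp.sum(x**2))


# Hypothetical stand-in for jitted_grad_ecc: the hand-written analytical gradient.
@jax.jit
def grad_ecc(x):
    return x / jnp.sqrt(jnp.sum(x**2))


def ecc_fwd(x):
    # Forward pass: primal output plus the residuals the backward pass needs.
    return ecc(x), x


def ecc_bwd(x, g):
    # Backward pass: scale the analytical gradient by the incoming cotangent g.
    # Must return a tuple with one entry per primal input.
    return (g * grad_ecc(x),)


# Register the rule: JAX now calls ecc_bwd instead of differentiating ecc's body.
ecc.defvjp(ecc_fwd, ecc_bwd)

x = jnp.array([3.0, 4.0])
value = jax.jit(ecc)(x)                # sqrt(9 + 16) = 5.0
gradient = jax.jit(jax.grad(ecc))(x)   # approximately [0.6, 0.8]
print(value, gradient)
```

One design note, stated as an assumption to verify against the PyMC version in use: when sampling through numpyro, the model's logp graph is converted to JAX and differentiated by JAX. A rule registered at the JAX level like this is therefore the kind JAX would pick up. A PyTensor Op's `grad` method is consulted by PyTensor's own symbolic differentiation instead.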

Since I noticed that this wasn’t happening, would adding something like

logp_fn = model.compile_fn(model.logp(sum=False), mode='JAX')

after the model is defined work?

I have checked out “How to use JAX ODEs and Neural Networks in PyMC” (PyMC Labs) and the link you sent before, and I think that approach should work. Am I missing anything?