I am trying to understand how the wider PyMC graph accepts custom gradients from PyTensor Ops. I have my own PyTensor Op to calculate gradients, but I have realized that the wider PyMC model I have isn’t accepting them (or doesn’t recognize them).

In this example, can someone explain how exactly the calculated gradients get sent to the wider PyMC model? I assumed it was in these lines:

``````
return [
    # We did not implement gradients wrt to the last 3 inputs
    # This won't be a problem for sampling, as those are constants in our model
]
``````

since I do not understand this line:

``````
[out_grad] = g
``````

For added context, my full Op is:

``````
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as at
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify


def jax_ecc_anom(manom, ecc):
    # Solve Kepler's equation for the eccentric anomaly
    # manom = np.reshape(manom, [7])
    # print(type(ecc))
    alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5)
    beta = (0.5 * manom) / ((4.0 * ecc) + 0.5)

    aux = jnp.sqrt(beta**2.0 + alpha**3.0)
    z = jnp.abs(beta + aux) ** (1.0 / 3.0)

    s0 = z - (alpha / z)
    s1 = s0 - (0.078 * (s0**5.0)) / (1.0 + ecc)
    e0 = manom + (ecc * (3.0 * s1 - 4.0 * (s1**3.0)))

    se0 = jnp.sin(e0)
    ce0 = jnp.cos(e0)

    f = e0 - ecc * se0 - manom

    f1 = 1.0 - ecc * ce0
    f2 = ecc * se0
    f3 = ecc * ce0
    f4 = -f2
    u1 = -f / f1
    u2 = -f / (f1 + 0.5 * f2 * u1)
    u3 = -f / (f1 + 0.5 * f2 * u2 + (1.0 / 6.0) * f3 * u2 * u2)
    u4 = -f / (f1 + 0.5 * f2 * u3 + (1.0 / 6.0) * f3 * u3 * u3 + (1.0 / 24.0) * f4 * (u3**3.0))

    return e0 + u4


jitted_ecc = jax.jit(jax_ecc_anom)


def jax_grad_ecc(EA, ecc):  # def line assumed; it was missing from the snippet I pasted
    # Analytical partial derivatives of the eccentric anomaly
    sea = jnp.sin(EA)
    cea = jnp.cos(EA)
    temp = 1 - ecc * cea
    dEA_dM = 1.0 / temp
    dEA_de = (sea * EA) / temp
    return dEA_dM, dEA_de


jitted_grad_ecc = jax.jit(jax_grad_ecc)


class KeplerSolverOp(Op):
    __props__ = ()

    def make_node(self, M, e):
        M = at.as_tensor(M)
        e = at.as_tensor_variable(e)

        inputs = [M, e]
        outputs = [M.type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs

        EA = jitted_ecc(manom, ecc)  # Compute eccentric anomaly

        outputs[0][0] = np.asarray(EA)

    def grad(self, inputs, output_gradients):
        # Only the first line of this method made it into the snippet
        M, e = inputs
        ...


class KeplerSolverGradOp(Op):  # class statement assumed; the Op computes the two partials
    def make_node(self, M, e):
        M = at.as_tensor(M)
        e = at.as_tensor_variable(e)

        inputs = [M, e]

        outputs = [M.type(), e.type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs

        EA = jitted_ecc(manom, ecc)
        dEA_dM, dEA_de = jitted_grad_ecc(EA, ecc)  # analytical partials

        outputs[0][0] = np.asarray(dEA_dM)
        outputs[1][0] = np.asarray(dEA_de)

        # partials(M)/temp + partials(e)*sea/temp
        # dEA_dM, dEA_de, temp, sea = jitted_grad_ecc(EA, ecc)
        # outputs[0][0] = (dEA_dM/temp) + (dEA_de*sea)/(temp)


@jax_funcify.register(KeplerSolverOp)  # done
def kepler_solver_funcify(op, **kwargs):  # def line assumed; it was missing from the snippet
    return jax_ecc_anom  # done
``````

That line is unpacking a list with a single element.

For a single element you need to distinguish between these two cases:

``````
x = [1]    # x is [1]
[x] = [1]  # x is 1
# same as (x,) = [1]
assert x == 1
``````

You can also unpack multiple elements, which is commonly seen in loops and function calls:

``````
[x, y] = [1, 2]
# same as x, y = [1, 2]
assert x + y == 3
``````

Anyway, Python syntax aside, PyMC requests your gradient expressions when you (or PyMC) ask for the gradient of the logp. You can trigger this manually with `model.dlogp()`. Note that if you are using an external JAX sampler like `numpyro` or `blackjax`, the gradient won't be provided by PyTensor but by JAX, and it doesn't matter whether or not you implemented it in PyTensor.
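For example (a toy model, not yours, and assuming your Op's `grad` method is fully implemented), asking for the gradient of the logp is what ends up calling `KeplerSolverOp.grad`:

``````
import pymc as pm

kepler_solver_op = KeplerSolverOp()

with pm.Model() as model:
    e = pm.Uniform("e", 0.0, 1.0)
    M = pm.Normal("M", 0.0, 1.0)
    EA = kepler_solver_op(M, e)
    pm.Normal("obs", mu=EA, sigma=0.1, observed=1.0)

# Building the gradient of the logp triggers KeplerSolverOp.grad
dlogp_graph = model.dlogp()        # symbolic gradient expressions
dlogp_fn = model.compile_dlogp()   # compiled gradient function
``````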

For a more introductory picture of how PyMC gets to probabilities (and then their gradients): PyMC and PyTensor — PyMC 5.16.2 documentation

So my goal in creating the previous PyTensor Op was to feed the analytical gradient computed by `jitted_grad_ecc` into PyMC’s wider graph as a sort of “differentiation rule” for only this specific `jitted_ecc` function while sampling with `numpyro`.

Because I noticed that this wasn’t happening, would adding something like

``````
logp_fn = model.compile_fn(model.logp(sum=False), mode='JAX')
``````

after the model is defined work?

I have checked out How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs and the link you sent before, and I think that should work? Would I be missing anything?

PyMC/PyTensor are not involved in the gradients when using numpyro. If you want to overload the gradients numpyro uses, you’ll have to do it with JAX’s custom gradient API: Custom derivative rules for JAX-transformable Python functions — JAX documentation
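A minimal sketch of what that can look like for your solver (reusing `jax_ecc_anom` and `jitted_grad_ecc` from your snippet; the wrapper names here are just placeholders):

``````
import jax

@jax.custom_vjp
def ecc_anom(manom, ecc):
    return jax_ecc_anom(manom, ecc)

def ecc_anom_fwd(manom, ecc):
    EA = jax_ecc_anom(manom, ecc)
    return EA, (EA, ecc)                       # residuals for the backward pass

def ecc_anom_bwd(res, g):
    EA, ecc = res
    dEA_dM, dEA_de = jitted_grad_ecc(EA, ecc)  # your analytical partials
    return (g * dEA_dM, g * dEA_de)            # one cotangent per input

ecc_anom.defvjp(ecc_anom_fwd, ecc_anom_bwd)
``````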

If you look at the last example in the blog post, you’ll notice we don’t bother defining gradients in PyTensor because we’ll sample with numpyro.
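Concretely, with numpyro the only thing the Op needs is a JAX translation registered through `jax_funcify`; something like this sketch (the registered function’s name is arbitrary, and returning the `custom_vjp`-wrapped `ecc_anom` from above is what makes your analytical rule kick in):

``````
from pytensor.link.jax.dispatch import jax_funcify

# No grad() needed on the PyTensor Op: when sampling with numpyro, PyTensor only
# has to translate the Op to JAX, and JAX handles the differentiation.
@jax_funcify.register(KeplerSolverOp)
def kepler_solver_to_jax(op, **kwargs):
    return ecc_anom
``````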