I am trying to understand how the wider PyMC graph accepts custom gradients from PyTensor Ops. I have my own PyTensor Op to calculate gradients, but I have realized that the wider PyMC model I have isn’t accepting them (or doesn’t recognize them).

In this example, can someone explain how exactly the calculated gradients get sent to the wider PyMC model? I assumed it was in these lines:

``````
return [
    # We did not implement gradients wrt to the last 3 inputs
    # This won't be a problem for sampling, as those are constants in our model
]
``````

since I do not understand this line:

``````
[out_grad] = g
``````

For added context, my full Op is:

``````
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as at
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify


def jax_ecc_anom(manom, ecc):
    # Solve Kepler's equation for the eccentric anomaly
    # manom = np.reshape(manom, [7])
    # print(type(ecc))
    alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5)
    beta = (0.5 * manom) / ((4.0 * ecc) + 0.5)

    aux = jnp.sqrt(beta**2.0 + alpha**3.0)
    z = jnp.abs(beta + aux) ** (1.0 / 3.0)

    s0 = z - (alpha / z)
    s1 = s0 - (0.078 * (s0**5.0)) / (1.0 + ecc)
    e0 = manom + (ecc * (3.0 * s1 - 4.0 * (s1**3.0)))

    se0 = jnp.sin(e0)
    ce0 = jnp.cos(e0)

    f = e0 - ecc * se0 - manom

    f1 = 1.0 - ecc * ce0
    f2 = ecc * se0
    f3 = ecc * ce0
    f4 = -f2
    u1 = -f / f1
    u2 = -f / (f1 + 0.5 * f2 * u1)
    u3 = -f / (f1 + 0.5 * f2 * u2 + (1.0 / 6.0) * f3 * u2 * u2)
    u4 = -f / (f1 + 0.5 * f2 * u3 + (1.0 / 6.0) * f3 * u3 * u3 + (1.0 / 24.0) * f4 * (u3**3.0))

    return e0 + u4


jitted_ecc = jax.jit(jax_ecc_anom)


def jax_grad_ecc(EA, ecc):  # def line assumed; it was missing from the snippet I pasted
    # Analytical partial derivatives of the eccentric anomaly
    sea = jnp.sin(EA)
    cea = jnp.cos(EA)
    temp = 1 - ecc * cea
    dEA_dM = 1.0 / temp
    dEA_de = (sea * EA) / temp
    return dEA_dM, dEA_de


jitted_grad_ecc = jax.jit(jax_grad_ecc)


class KeplerSolverOp(Op):
    __props__ = ()

    def make_node(self, M, e):
        M = at.as_tensor(M)
        e = at.as_tensor_variable(e)

        inputs = [M, e]
        outputs = [M.type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs

        EA = jitted_ecc(manom, ecc)  # Compute eccentric anomaly

        outputs[0][0] = np.asarray(EA)

    def grad(self, inputs, output_gradients):
        # Only the first line of this method made it into the snippet
        M, e = inputs
        ...


class KeplerSolverGradOp(Op):  # class statement assumed; the Op computes the two partials
    def make_node(self, M, e):
        M = at.as_tensor(M)
        e = at.as_tensor_variable(e)

        inputs = [M, e]

        outputs = [M.type(), e.type()]

        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs

        EA = jitted_ecc(manom, ecc)
        dEA_dM, dEA_de = jitted_grad_ecc(EA, ecc)  # analytical partials

        outputs[0][0] = np.asarray(dEA_dM)
        outputs[1][0] = np.asarray(dEA_de)

        # partials(M)/temp + partials(e)*sea/temp
        # dEA_dM, dEA_de, temp, sea = jitted_grad_ecc(EA, ecc)
        # outputs[0][0] = (dEA_dM/temp) + (dEA_de*sea)/(temp)


@jax_funcify.register(KeplerSolverOp)  # done
def kepler_solver_funcify(op, **kwargs):  # def line assumed; it was missing from the snippet
    return jax_ecc_anom  # done
``````

That line is unpacking a list with a single element.

For a single element you need to distinguish between these two cases:

``````
x = [1]    # x is [1]
[x] = [1]  # x is 1
# same as (x,) = [1]
assert x == 1
``````

You can also unpack multiple elements, which is commonly seen in loops and function calls:

``````
[x, y] = [1, 2]
# same as x, y = [1, 2]
assert x + y == 3
``````

Anyway, Python syntax aside, PyMC requests your gradient expressions when you (or PyMC) ask for the gradient of the logp. You can trigger this manually with `model.dlogp()`. Note that if you are using an external JAX sampler like `numpyro` or `blackjax`, the gradient won't be provided by PyTensor but by JAX, and it doesn't matter whether or not you implemented it in PyTensor.
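For example (a toy model, not yours, and assuming your Op's `grad` method is fully implemented), asking for the gradient of the logp is what ends up calling `KeplerSolverOp.grad`:

``````
import pymc as pm

kepler_solver_op = KeplerSolverOp()

with pm.Model() as model:
    e = pm.Uniform("e", 0.0, 1.0)
    M = pm.Normal("M", 0.0, 1.0)
    EA = kepler_solver_op(M, e)
    pm.Normal("obs", mu=EA, sigma=0.1, observed=1.0)

# Building the gradient of the logp triggers KeplerSolverOp.grad
dlogp_graph = model.dlogp()        # symbolic gradient expressions
dlogp_fn = model.compile_dlogp()   # compiled gradient function
``````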

For a more introductory picture of how PyMC gets to probabilities (and then their gradients): PyMC and PyTensor — PyMC 5.16.2 documentation

So my goal in creating the previous PyTensor Op was to feed the analytical gradient computed by `jitted_grad_ecc` into PyMC’s wider graph as a sort of “differentiation rule” for only this specific `jitted_ecc` function while sampling with `numpyro`.

Because I noticed that this wasn’t happening, would adding something like

``````
logp_fn = model.compile_fn(model.logp(sum=False), mode='JAX')
``````

after the model is defined work?

I have checked out How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs and the link you sent before, and I think that should work? Would I be missing anything?

PyMC/PyTensor are not involved in the gradients when using numpyro. If you want to overload the gradients numpyro uses, you’ll have to do it with JAX’s custom gradient API: Custom derivative rules for JAX-transformable Python functions — JAX documentation
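A minimal sketch of what that can look like for your solver (reusing `jax_ecc_anom` and `jitted_grad_ecc` from your snippet; the wrapper names here are just placeholders):

``````
import jax

@jax.custom_vjp
def ecc_anom(manom, ecc):
    return jax_ecc_anom(manom, ecc)

def ecc_anom_fwd(manom, ecc):
    EA = jax_ecc_anom(manom, ecc)
    return EA, (EA, ecc)                       # residuals for the backward pass

def ecc_anom_bwd(res, g):
    EA, ecc = res
    dEA_dM, dEA_de = jitted_grad_ecc(EA, ecc)  # your analytical partials
    return (g * dEA_dM, g * dEA_de)            # one cotangent per input

ecc_anom.defvjp(ecc_anom_fwd, ecc_anom_bwd)
``````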

If you look at the last example in the blog post, you’ll notice we don’t bother defining gradients in PyTensor because we’ll sample with numpyro.
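Concretely, with numpyro the only thing the Op needs is a JAX translation registered through `jax_funcify`; something like this sketch (the registered function’s name is arbitrary, and returning the `custom_vjp`-wrapped `ecc_anom` from above is what makes your analytical rule kick in):

``````
from pytensor.link.jax.dispatch import jax_funcify

# No grad() needed on the PyTensor Op: when sampling with numpyro, PyTensor only
# has to translate the Op to JAX, and JAX handles the differentiation.
@jax_funcify.register(KeplerSolverOp)
def kepler_solver_to_jax(op, **kwargs):
    return ecc_anom
``````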