I am trying to understand how the wider PyMC graph picks up custom gradients from a PyTensor Op. I have written my own PyTensor Op with a gradient, but I have realized that the wider PyMC model isn't using those gradients (or doesn't recognize them).
In this example, can someone explain exactly how the computed gradients get passed to the wider PyMC model? I assumed it happens in these lines:
return [
    pt.sum(out_grad * grad_wrt_m),
    pt.sum(out_grad * grad_wrt_c),
    # We did not implement gradients wrt to the last 3 inputs
    # This won't be a problem for sampling, as those are constants in our model
    pytensor.gradient.grad_not_implemented(self, 2, sigma),
    pytensor.gradient.grad_not_implemented(self, 3, x),
    pytensor.gradient.grad_not_implemented(self, 4, data),
]
partly because I do not understand this line:
[out_grad] = g
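To make my confusion concrete, here is a toy single-output Op I put together while reading (the Op and all names here are mine, not from the tutorial). My current understanding is that g is a list with one upstream gradient per output, so [out_grad] = g just unpacks it, and the list returned by grad() needs one entry per input; please correct me if that is wrong:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

class DoubleOp(Op):
    """Toy Op (y = 2 * x), just to illustrate the grad() signature."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = 2 * inputs[0]

    def grad(self, inputs, g):
        [out_grad] = g          # one upstream gradient per output
        return [2 * out_grad]   # one gradient expression per input

x = pt.dvector("x")
cost = DoubleOp()(x).sum()
gx = pt.grad(cost, x)           # this is where DoubleOp.grad gets called
f = pytensor.function([x], gx)
print(f(np.array([1.0, 2.0])))  # expect [2. 2.]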
For added context, here is my full Op:
import jax
import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify


def jax_ecc_anom(manom, ecc):
    # Solve Kepler's equation for the eccentric anomaly: cubic-approximation
    # initial guess, followed by a single high-order Newton-style refinement.
    alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5)
    beta = (0.5 * manom) / ((4.0 * ecc) + 0.5)
    aux = jnp.sqrt(beta**2.0 + alpha**3.0)
    z = jnp.abs(beta + aux) ** (1.0 / 3.0)
    s0 = z - (alpha / z)
    s1 = s0 - (0.078 * (s0**5.0)) / (1.0 + ecc)
    e0 = manom + (ecc * (3.0 * s1 - 4.0 * (s1**3.0)))
    se0 = jnp.sin(e0)
    ce0 = jnp.cos(e0)
    f = e0 - ecc * se0 - manom
    f1 = 1.0 - ecc * ce0
    f2 = ecc * se0
    f3 = ecc * ce0
    f4 = -f2
    u1 = -f / f1
    u2 = -f / (f1 + 0.5 * f2 * u1)
    u3 = -f / (f1 + 0.5 * f2 * u2 + (1.0 / 6.0) * f3 * u2 * u2)
    u4 = -f / (f1 + 0.5 * f2 * u3 + (1.0 / 6.0) * f3 * u3 * u3 + (1.0 / 24.0) * f4 * (u3**3.0))
    return e0 + u4
jitted_ecc = jax.jit(jax_ecc_anom)
def grad_ecc(EA, ecc):
    # Analytic partial derivatives of the eccentric anomaly EA with respect to
    # the mean anomaly M and the eccentricity e.
    sea = jnp.sin(EA)
    cea = jnp.cos(EA)
    temp = 1 - ecc * cea
    dEA_dM = 1.0 / temp
    dEA_de = (sea * EA) / temp
    return dEA_dM, dEA_de

jitted_grad_ecc = jax.jit(grad_ecc)
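For what it's worth, a quick way to check the analytic gradients against JAX's autodiff of the solver itself would be something like this (arbitrary test point, not part of the Op):

# Compare grad_ecc against jax.grad of the solver at a single test point.
M_test, e_test = 1.3, 0.4
EA_test = jitted_ecc(M_test, e_test)
auto_dM = jax.grad(jax_ecc_anom, argnums=0)(M_test, e_test)
auto_de = jax.grad(jax_ecc_anom, argnums=1)(M_test, e_test)
print(auto_dM, auto_de)
print(jitted_grad_ecc(EA_test, e_test))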
class KeplerSolverOp(Op):
    __props__ = ()

    def make_node(self, M, e):
        M = pt.as_tensor(M)
        e = pt.as_tensor_variable(e)
        inputs = [M, e]
        outputs = [M.type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs
        EA = jitted_ecc(manom, ecc)  # compute the eccentric anomaly
        outputs[0][0] = np.asarray(EA)

    def grad(self, inputs, out_grad):
        M, e = inputs
        grad_wrt_m, grad_wrt_e = kepler_loglikegrad(M, e)
        # Note: out_grad is not used here; I just return the raw partials.
        return [grad_wrt_m, grad_wrt_e]
class LogLikeGrad(Op):
    def make_node(self, M, e):
        M = pt.as_tensor(M)
        e = pt.as_tensor_variable(e)
        inputs = [M, e]
        outputs = [M.type(), e.type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        manom, ecc = inputs
        EA = jitted_ecc(manom, ecc)
        dEA_dM, dEA_de = jitted_grad_ecc(EA, ecc)
        # Convert JAX arrays to NumPy arrays for PyTensor's output storage.
        outputs[0][0] = np.asarray(dEA_dM)
        outputs[1][0] = np.asarray(dEA_de)
        # partials(M)/temp + partials(e)*sea/temp
        # dEA_dM, dEA_de, temp, sea = jitted_grad_ecc(EA, ecc)
        # outputs[0][0] = (dEA_dM/temp) + (dEA_de*sea)/(temp)
keplergrad = KeplerSolverOp()
kepler_loglikegrad = LogLikeGrad()
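To check whether the gradients are actually reachable from a PyTensor graph, this is the kind of minimal test I have been running (the symbolic names here are just for the check):

# Build a tiny graph through the Op and ask PyTensor for the gradient.
M_sym = pt.dscalar("M")
e_sym = pt.dscalar("e")
EA_sym = keplergrad(M_sym, e_sym)
# pt.grad is what ends up calling KeplerSolverOp.grad; if the Op's gradients
# were not wired up correctly, this is where it should show.
dM_sym, de_sym = pt.grad(EA_sym, [M_sym, e_sym])
grad_fn = pytensor.function([M_sym, e_sym], [dM_sym, de_sym])
print(grad_fn(1.3, 0.4))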
@jax_funcify.register(KeplerSolverOp)
def keplergrad_jaxify(op, **kwargs):
    return jax_ecc_anom


@jax_funcify.register(LogLikeGrad)
def kepler_loglikegrad_jaxify(op, **kwargs):
    return grad_ecc
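Finally, this is roughly where the Op enters my PyMC model (heavily simplified; the priors, likelihood, and data below are placeholders), which is where I expected the custom gradients to be used by the sampler:

import pymc as pm

# Simplified stand-in for my real model, just to show where the Op (and hence
# its grad) enters the graph.
fake_obs = np.array(0.9)
with pm.Model():
    ecc = pm.Uniform("ecc", 0.0, 0.95)
    manom = pm.Normal("manom", 0.0, 1.0)
    EA = keplergrad(manom, ecc)
    pm.Normal("obs", mu=EA, sigma=0.1, observed=fake_obs)
    # Gradient-based samplers differentiate the model logp, which is where I
    # expected KeplerSolverOp.grad (and, with nuts_sampler="numpyro", the
    # jax_funcify registrations) to come into play.
    idata = pm.sample()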