How can I return a gradient as a vector from the `grad` method of a custom `Op`?

Did you have a look at this? How to wrap a JAX function for use in PyMC (PyMC example gallery)
