# How can I output a gradient in vector format in Op.grad instance?

Hello,
I am new to PyMC. I’ve been using JAX+numpyro, but now I’m going to PyMC 4.0 with JAX. I am trying to wrap a JAX function into `aesara` while reading this page.

I am having trouble specifying the gradient of a vector type output in `Op`. Let me explain.

• environment: PyMC version 4.1.5 / JAX version 0.3.1

I’d like to run HMC-NUTS to infer a parameter (`phase`) in a sine function model, `y = sin(x + phase)`. I generated simple mock data as follows:

```python
import numpy as np

np.random.seed(32)
phase = 0.5
sigin = 0.3  # standard deviation of the injected Gaussian noise
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi
y = np.sin(x + phase) + np.random.normal(0, sigin, size=N)
```

Then I’d like to fit the JAX-based sine model to the data:

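```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import config

# Use float64 throughout
config.update("jax_enable_x64", True)
x = np.array(x, dtype=np.float64)
y = np.array(y, dtype=np.float64)
phase = np.float64(phase)

def sinmodel(x, phase):
    return jnp.sin(x + phase)

# Derivative with respect to phase, evaluated separately at each element of x
grad_sinmodel = jax.jit(jax.vmap(jax.grad(sinmodel, argnums=1),
                                 in_axes=(0, None)))
```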

Then, both `sinmodel` and `grad_sinmodel` return a 1-dimensional array (JAX DeviceArray) with 20 elements. To use the JAX functions in PyMC, I need to wrap them in an `Op`. First, the gradient:

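```python
import aesara
import aesara.tensor as at
from aesara.graph import Apply, Op

class HMCGradOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # Evaluate the JAX derivative and store it in the output buffer
        grad_phase = grad_sinmodel(*inputs)
        outputs[0][0] = np.asarray(grad_phase, dtype=node.outputs[0].dtype)

hmc_grad_op = HMCGradOp()
hmc_grad_op(x, phase).eval()  # for check -> it works
```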

Then I defined the `Op` for the model itself:

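```python
class HMCOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = sinmodel(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        grad_phase = hmc_grad_op(*inputs)
        output_gradient = output_gradients[0]
        # x is fixed data, so its gradient is left unimplemented
        return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
                output_gradient * grad_phase]

hmc_op = HMCOp()
```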

I suspect that the last line, especially `output_gradient * grad_phase`, is the cause of the following error. Next, I tried to run the sampling as

```python
import jax.numpy as jnp
import pymc

def build_model(x, y):
    with pymc.Model() as pmmodel:
        phase = pymc.Uniform('phase', lower=-1.0*jnp.pi, upper=1.0*jnp.pi)
        sigma = pymc.Exponential('sigma', lam=1.)
        mu = hmc_op(x, phase)
        # mu = at.sin(x + phase)  # this works
        d = pymc.Normal('y', mu=mu, sigma=sigma, observed=y)
    return pmmodel
```

```python
model_pymc = build_model(x, y)
with model_pymc:
    idata = pymc.sample(return_inferencedata=True)
```

Then I got a `ValueError`:

```
ValueError: HMCOp.grad returned a term with 1 dimensions, but 0 are required.
```

Prior to the sampling, a quick check of `grad`

```python
hmc_op = HMCOp()
at_phase = at.as_tensor_variable(phase)
yt = hmc_op(x, at_phase)
```

also resulted in the same error.

So, I suspect that I got this error because `output_gradient * grad_phase` is 1-dimensional rather than a scalar. This can be checked by using

```python
        return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
```

instead. In this case, I was able to run HMC-NUTS at least formally (though of course the initialization failed because the derivative is wrong).

```
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
```

So I think I need to know how to specify the gradient of a vector-valued output in an `Op`, but I have no idea how. Does anyone know how to resolve this problem?
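For reference, the shape requirement in the error message can be read off the chain rule. In the sketch below, `C` stands for whatever scalar the sampler ends up differentiating (for example the log-probability) and `g_i` for the entries of the `output_gradient` passed to `grad`. Both symbols are introduced only for this note and do not appear in the code above.

```latex
\frac{\partial C}{\partial \mathrm{phase}}
  = \sum_{i=1}^{N} \frac{\partial C}{\partial y_i}\,
    \frac{\partial y_i}{\partial \mathrm{phase}}
  = \sum_{i=1}^{N} g_i \cos(x_i + \mathrm{phase}),
\qquad g_i = \frac{\partial C}{\partial y_i}
```

Because `phase` is a scalar, this term is a scalar too, whereas the elementwise product `output_gradient * grad_phase` keeps all `N` entries without summing them.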

Did you have a look at this? How to wrap a JAX function for use in PyMC — PyMC example gallery


Yes! I started from that page. But the difference is that the example on that page uses a scalar output for each element in the return value of `grad`, as in

```python
    def grad(self, inputs, output_gradients):
        (
            ...
        ) = ...
        # If there are inputs for which the gradients will never be needed or cannot
        # be computed, `aesara.gradient.grad_not_implemented` should be used as the
        # output gradient for that input.
        return [
            ...
        ]
```

In my case, I’d like to return `[scalar, vector]` instead.

Perhaps have a look at the gradient of some vector output `Op`s in Aesara, like `Softmax`? aesara/basic.py at 9f176da71d41635e8854fd601fd6a68102b0c6e5 · aesara-devs/aesara · GitHub
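As a rough illustration of what the gradient of a vector-output operation like softmax looks like, here is a minimal sketch in plain Python (standard library only). It uses the textbook softmax vector-Jacobian product and is not a transcription of Aesara's implementation. The names `softmax`, `softmax_vjp`, `x`, and `g` are made up for this example.

```python
import math

def softmax(x):
    """Numerically stable softmax of a list of floats."""
    m = max(x)
    e = [math.exp(xi - m) for xi in x]
    s = sum(e)
    return [ei / s for ei in e]

def softmax_vjp(x, g):
    """Gradient term for x, given the upstream gradient g = dC/dy.

    Uses dx_j = y_j * (g_j - sum_i g_i * y_i), so the result has the
    same length as x even though the output y is a vector as well.
    """
    y = softmax(x)
    dot = sum(gi * yi for gi, yi in zip(g, y))
    return [yj * (gj - dot) for yj, gj in zip(y, g)]

# Hypothetical input and upstream gradient, chosen arbitrarily.
x = [0.2, -1.0, 0.5]
g = [1.0, 2.0, -0.5]
dx = softmax_vjp(x, g)
```

The thing to notice is the inner sum over the output index: the returned term is always shaped like the input, not like the output.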


Thanks for sharing this example. I will take a look and compare it to my code.

Hi - I’m encountering a similar problem in generalizing an Op to take vector arguments instead of scalars. I was wondering if you figured out what the problem was in the end. Thanks!

If you are familiar with `JAX`, the `grad` method (or even better the `L_op`) is very similar to the `vjp`. There are some examples of how they match here: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs

Otherwise, if you can share the exact `Op` you are trying to implement, we can provide more specific guidance.
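To make the `vjp` analogy concrete for the sine model from the original question, here is a minimal sketch in plain Python (standard library only, so it runs without JAX or Aesara). The helper names `sin_model`, `vjp_x`, and `vjp_phase` are made up for illustration, and the input values are arbitrary. The point is that the term returned for each input has that input's shape: a vector for `x`, and a scalar for `phase`, obtained by summing over the output dimension.

```python
import math

def sin_model(x, phase):
    """Vector output y_i = sin(x_i + phase) for a list x and a scalar phase."""
    return [math.sin(xi + phase) for xi in x]

def vjp_x(x, phase, g):
    """Term for the vector input x: same length as x (the Jacobian is diagonal)."""
    return [gi * math.cos(xi + phase) for xi, gi in zip(x, g)]

def vjp_phase(x, phase, g):
    """Term for the scalar input phase: contract g with dy/dphase into one number."""
    return sum(gi * math.cos(xi + phase) for xi, gi in zip(x, g))

# Hypothetical inputs and an arbitrary upstream gradient g = dC/dy.
x = [0.1, 0.7, 1.3, 2.9]
phase = 0.5
g = [1.0, -2.0, 0.5, 3.0]

def cost(p):
    """Scalar C(phase) = sum_i g_i * y_i(phase), used for a finite-difference check."""
    return sum(gi * yi for gi, yi in zip(g, sin_model(x, p)))

eps = 1e-6
fd_phase = (cost(phase + eps) - cost(phase - eps)) / (2 * eps)
analytic_phase = vjp_phase(x, phase, g)
```

In the Aesara `Op` from the question, the analogous change would presumably be to return `(output_gradient * grad_phase).sum()` for `phase` rather than the elementwise product `output_gradient * grad_phase`. I have not run this against the exact code above, so treat it as a starting point.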

I added a description of the Op I’m trying to write here. Any help is very much appreciated! I am still not totally clear on some details, so it would be great to have someone with more experience take a look. If we find a way of making it work well, it might be worth adding it to the pymc API? Integrations with no closed-form solution show up pretty regularly in my life : )