Hello,

I am new to PyMC. I’ve been using JAX + NumPyro, but now I’m moving to PyMC 4.0 with JAX. I am trying to wrap a JAX function into `aesara` while reading this page.

I am having trouble specifying the gradient of a vector-valued output in an `Op`. Let me explain.

- environment: PyMC version 4.1.5 / JAX version 0.3.1

I’d like to run HMC-NUTS to infer a parameter (`phase`) in a sine model, `y = sin(x + phase)`. I generated simple mock data as follows:

```
import numpy as np

np.random.seed(32)
phase = 0.5
sigin = 0.3  # noise standard deviation
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi
y = np.sin(x + phase) + np.random.normal(0, sigin, size=N)
```

Then, I’d like to fit the JAX-based sine model to the data:

```
import jax.numpy as jnp
import jax
from jax import config

# enable float64 in JAX
config.update("jax_enable_x64", True)
x = np.array(x, dtype=np.float64)
y = np.array(y, dtype=np.float64)
phase = np.float64(phase)

def sinmodel(x, phase):
    return jnp.sin(x + phase)

# d(sinmodel)/d(phase), evaluated for each element of x
grad_sinmodel = jax.jit(jax.vmap(jax.grad(sinmodel, argnums=1), in_axes=(0, None)))
```
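The derivative that `grad_sinmodel` computes can be checked without JAX. The NumPy-only sketch below (the helper `sinmodel_np` is a hypothetical stand-in for `sinmodel`) compares the analytic derivative `cos(x + phase)` with a central finite difference. It yields one value per data point, i.e. a length-20 vector, so `np.asarray(grad_sinmodel(x, phase))` should match `analytic` to floating-point precision.

```python
import numpy as np

# Regenerates the same x as the mock-data block (seed 32, N = 20).
np.random.seed(32)
phase = 0.5
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi

def sinmodel_np(x, phase):
    """Hypothetical NumPy stand-in for the JAX `sinmodel`."""
    return np.sin(x + phase)

# Analytic derivative: d/dphase sin(x + phase) = cos(x + phase), one value per x_i.
analytic = np.cos(x + phase)

# Central finite difference in phase, as an independent check.
eps = 1e-6
numeric = (sinmodel_np(x, phase + eps) - sinmodel_np(x, phase - eps)) / (2 * eps)

print(analytic.shape)  # (20,)
print(np.max(np.abs(analytic - numeric)))
```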

Both `sinmodel` and `grad_sinmodel` then return a 1-dimensional array (JAX DeviceArray) with 20 elements. To use these JAX functions in PyMC, I need to wrap them in an `Op`.

```
import aesara
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op

class HMCGradOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        grad_phase = grad_sinmodel(*inputs)
        outputs[0][0] = np.asarray(grad_phase, dtype=node.outputs[0].dtype)

hmc_grad_op = HMCGradOp()
hmc_grad_op(x, phase).eval()  # for checking -> it works
```

Then, I defined the `Op` for the model itself:

```
class HMCOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = sinmodel(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        grad_phase = hmc_grad_op(*inputs)
        output_gradient = output_gradients[0]
        return [
            aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
            output_gradient * grad_phase,
        ]
```

I suspect that the last line, especially `output_gradient * grad_phase`, is the cause of the error below. Next, I try to run the sampling:

```
import jax.numpy as jnp
import pymc

def build_model(x, y):
    with pymc.Model() as pmmodel:
        phase = pymc.Uniform('phase', lower=-1.0 * jnp.pi, upper=1.0 * jnp.pi)
        sigma = pymc.Exponential('sigma', lam=1.)
        mu = hmc_op(x, phase)
        # mu = at.sin(x + phase)  # this works
        d = pymc.Normal('y', mu=mu, sigma=sigma, observed=y)
    return pmmodel
```

```
model_pymc = build_model(x,y)
with model_pymc:
idata = pymc.sample(return_inferencedata=True)
```

Then I got the following `ValueError`:

```
ValueError: HMCOp.grad returned a term with 1 dimensions, but 0 are required.
```

Prior to the sampling, a quick check of `grad`

```
hmc_op = HMCOp()
at_phase = at.as_tensor_variable(phase)
yt = hmc_op(x,at_phase)
yt_grad = at.grad(at.max(yt),wrt=at_phase)
yt_grad.eval()
```

also resulted in the same error.

So, I suspect I get this error because `output_gradient * grad_phase` is 1-dimensional, not a scalar. This can be checked by using

```
        return [
            aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
            output_gradient[1] * at.sum(grad_phase),  # passed, dimensions matched
        ]
```

instead. In this case, I was able to run HMC-NUTS, at least formally (though, of course, initialization failed because the derivative is wrong):

```
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...
```

So, I think I need to know how to specify vector-valued gradient information in an `Op`, but I have no idea how. Does anyone know how to resolve this problem?
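For reference, here is a NumPy-only sketch of the chain rule involved, under the assumption that `Op.grad` receives `output_gradients[0]` as the gradient of some scalar cost C with respect to the length-N output and must return, for each input, a term with that input's shape (which is what the `0 are required` message suggests for the scalar `phase`). The cost C and its weights `w` are made up purely for illustration; the check compares the element-wise product with the contracted vector-Jacobian product and a finite difference.

```python
import numpy as np

# Same x and phase as the mock data (seed 32, N = 20).
np.random.seed(32)
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi
phase = 0.5

def mu(phase):
    # Model output: a length-N vector, like the output of HMCOp.
    return np.sin(x + phase)

# Jacobian column d mu_i / d phase, i.e. what grad_sinmodel returns (shape (N,)).
jac = np.cos(x + phase)

# Hypothetical scalar cost C = sum(w * mu) with fixed weights w,
# so the upstream gradient g_i = dC/dmu_i is simply w.
rng = np.random.default_rng(0)
w = rng.normal(size=N)

def cost(phase):
    return np.sum(w * mu(phase))

g = w  # dC/dmu, plays the role of output_gradients[0]

# Element-wise product keeps the output index: shape (N,), not the 0-d shape of phase.
elementwise = g * jac

# Vector-Jacobian product: contract over the output index -> a scalar, like phase.
vjp = np.sum(g * jac)

# Central finite-difference check of dC/dphase.
eps = 1e-6
fd = (cost(phase + eps) - cost(phase - eps)) / (2 * eps)

print(elementwise.shape, np.ndim(vjp))  # (20,) 0
print(vjp, fd)
```

If that assumption is right, the `phase` term in `HMCOp.grad` would be the contraction `at.sum(output_gradient * grad_phase)` rather than `output_gradient * grad_phase`; this has not been tested with PyMC here.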

Thanks in advance.