How can I output a gradient in vector format from an Op.grad instance?

Hello,
I am new to PyMC. I’ve been using JAX + numpyro, but now I’m moving to PyMC 4.0 with JAX. I am trying to wrap a JAX function in Aesara, following this page.

I am having trouble specifying the gradient of a vector type output in Op. Let me explain.

  • environment: PyMC version 4.1.5 / JAX version 0.3.1

I’d like to run HMC-NUTS to infer a parameter (phase) of a sine model, y = sin(x + phase). I generated simple mock data as follows:

import numpy as np

np.random.seed(32)
phase = 0.5
sigin = 0.3  # noise level
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi
y = np.sin(x + phase) + np.random.normal(0, sigin, size=N)

[figure hmc1: plot of the mock data]

Then, I’d like to fit the JAX-based sine model to the data:

import jax
import jax.numpy as jnp
from jax import config

# enable float64
config.update("jax_enable_x64", True)
x = np.array(x, dtype=np.float64)
y = np.array(y, dtype=np.float64)
phase = np.float64(phase)

def sinmodel(x, phase):
    return jnp.sin(x + phase)

# element-wise derivative of sinmodel with respect to phase
grad_sinmodel = jax.jit(jax.vmap(jax.grad(sinmodel, argnums=1),
                                 in_axes=(0, None)))

Then, both sinmodel and grad_sinmodel return a one-dimensional array (a JAX DeviceArray) with 20 elements, as the quick check below shows. To use the JAX functions in PyMC, I need to wrap them in an Op.
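(A shape check, just for illustration, using the x and phase defined above:)

print(sinmodel(x, phase).shape)       # (20,)
print(grad_sinmodel(x, phase).shape)  # (20,)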

import aesara
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op

class HMCGradOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        grad_phase = grad_sinmodel(*inputs)
        outputs[0][0] = np.asarray(grad_phase, dtype=node.outputs[0].dtype)

hmc_grad_op = HMCGradOp()
hmc_grad_op(x, phase).eval()  # quick check -> it works
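(As expected, this evaluates to the analytic derivative cos(x + phase); a sanity check, shown here only for illustration:)

print(np.allclose(hmc_grad_op(x, phase).eval(), np.cos(x + phase)))  # True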

Then, I defined the main Op:

class HMCOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = sinmodel(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        grad_phase = hmc_grad_op(*inputs)
        output_gradient = output_gradients[0]
        # the gradient with respect to the data x is never needed
        return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
                output_gradient * grad_phase]

hmc_op = HMCOp()

I suspect that the last line, specifically output_gradient * grad_phase, is the cause of the error below. Next, I try to run the sampler:

import pymc

def build_model(x, y):
    with pymc.Model() as pmmodel:
        phase = pymc.Uniform('phase', lower=-jnp.pi, upper=jnp.pi)
        sigma = pymc.Exponential('sigma', lam=1.)
        mu = hmc_op(x, phase)
        #mu = at.sin(x + phase)  # this works
        d = pymc.Normal('y', mu=mu, sigma=sigma, observed=y)
    return pmmodel

model_pymc = build_model(x, y)
with model_pymc:
    idata = pymc.sample(return_inferencedata=True)

Then I got a ValueError:

ValueError: HMCOp.grad returned a term with 1 dimensions, but 0 are required.

Prior to sampling, a quick check of the gradient,

at_phase = at.as_tensor_variable(phase)
yt = hmc_op(x, at_phase)
yt_grad = at.grad(at.max(yt), wrt=at_phase)
yt_grad.eval()

also resulted in the same error.

So I suspect that I got this error because output_gradient * grad_phase is one-dimensional, not a scalar. This can be checked by using

        return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
                output_gradient[1] * at.sum(grad_phase)]  # passes: dimensions match

instead. In this case, I was able to run HMC-NUTS at least formally (though of course initialization failed, because the derivative is wrong):

Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...

So I think I need to know how to specify gradient information for a vector-valued output in an Op, but I have no idea how. Does anyone know how to resolve this problem?

Thanks in advance.

Did you have a look at this? How to wrap a JAX function for use in PyMC — PyMC example gallery


Yes! I started from that page. But the difference is that the example on that page wraps a function with a scalar output, so each term in the list returned by grad is just the scalar output_gradient times a gradient of matching shape:

    def grad(self, inputs, output_gradients):
        (
            grad_wrt_emission_observed,
            grad_wrt_emission_signal,
            grad_wrt_emission_noise,
            grad_wrt_logp_initial_state,
            grad_wrt_logp_transition,
        ) = hmm_logp_grad_op(*inputs)
        # If there are inputs for which the gradients will never be needed or cannot
        # be computed, `aesara.gradient.grad_not_implemented` should be used as the
        # output gradient for that input.
        output_gradient = output_gradients[0]
        return [
            output_gradient * grad_wrt_emission_observed,
            output_gradient * grad_wrt_emission_signal,
            output_gradient * grad_wrt_emission_noise,
            output_gradient * grad_wrt_logp_initial_state,
            output_gradient * grad_wrt_logp_transition,
        ]

In my case, I’d like to return [scalar, vector] instead.

Perhaps have a look at the gradient of some vector-output Ops in Aesara, like Softmax? aesara/basic.py at 9f176da71d41635e8854fd601fd6a68102b0c6e5 · aesara-devs/aesara · GitHub


Thanks for sharing this example. I will take a look and compare it to my code.

Hi - I’m encountering a similar problem in generalizing an Op to take vector arguments instead of scalars. I was wondering if you figured out what the problem was in the end. Thanks!

If you are familiar with JAX, the grad method (or, even better, the L_op) is very similar to the vjp. There are some examples of how they match here: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs
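For the sine Op in the original post, for example, here is a minimal sketch of a grad that does the contraction explicitly (assuming the HMCOp and hmc_grad_op defined above; the key constraint is that the term returned for each input must have the same shape as that input, so the vector term is summed down to a scalar for phase):

    def grad(self, inputs, output_gradients):
        # vjp / chain rule: d(cost)/d(phase) = sum_i d(cost)/d(y_i) * d(y_i)/d(phase)
        grad_phase = hmc_grad_op(*inputs)      # vector of d(y_i)/d(phase)
        output_gradient = output_gradients[0]  # vector of d(cost)/d(y_i)
        return [
            aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
            at.sum(output_gradient * grad_phase),  # scalar, matching `phase`
        ]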

Otherwise, if you can share the exact Op you are trying to implement, we can provide more specific guidance.

I added a description of the Op I’m trying to write here - any help is very much appreciated! I am still not totally clear on some details, so it would be great to have someone with more experience take a look. If we find a way of making it work well, it might be worth adding to the pymc API? Integrations with no closed-form solution show up pretty regularly in my life : )