Hello,
I am new to PyMC. I have been using JAX + NumPyro, but I am now moving to PyMC v4 with JAX, and I am trying to wrap a JAX function into an Aesara Op while following this page. I am having trouble specifying the gradient of a vector-type output in the Op. Let me explain.
- environment: PyMC version 4.1.5 / JAX version 0.3.1
I would like to run HMC-NUTS to infer a parameter (phase) of a sine model, y = sin(x + phase). I generated simple mock data as follows:
import numpy as np

np.random.seed(32)
phase = 0.5
sigin = 0.3  # noise level of the mock data
N = 20
x = np.sort(np.random.rand(N)) * 4 * np.pi
y = np.sin(x + phase) + np.random.normal(0, sigin, size=N)
Then, I’d like to fit the JAX-based sine model to the data:
import jax.numpy as jnp
import jax
from jax import config

# enable float64 precision
config.update("jax_enable_x64", True)

x = np.array(x, dtype=np.float64)
y = np.array(y, dtype=np.float64)
phase = np.float64(phase)

def sinmodel(x, phase):
    return jnp.sin(x + phase)

# derivative of the model with respect to phase, vectorized over x
grad_sinmodel = jax.jit(jax.vmap(jax.grad(sinmodel, argnums=1), in_axes=(0, None)))
Then, both sinmodel and grad_sinmodel return a 1-dimensional array (JAX DeviceArray) with 20 elements. To use the JAX functions in PyMC, I need to wrap them in Ops.
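Just to make the shapes explicit, here is a quick check I added (the exact array type may differ between JAX versions):
# both calls return a 1-D array of length N = 20
print(sinmodel(x, phase).shape)       # (20,)
print(grad_sinmodel(x, phase).shape)  # (20,)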
import aesara
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op

class HMCGradOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # evaluate the JAX gradient and store it in the output storage
        grad_phase = grad_sinmodel(*inputs)
        outputs[0][0] = np.asarray(grad_phase, dtype=node.outputs[0].dtype)

hmc_grad_op = HMCGradOp()
hmc_grad_op(x, phase).eval()  # quick check -> it works
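As an extra sanity check (my own addition, not strictly necessary), the wrapped Op should agree numerically with the raw JAX gradient:
# compare the Op output with the plain JAX gradient; I expect this to be True
np.allclose(hmc_grad_op(x, phase).eval(), grad_sinmodel(x, phase))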
Then, I defined the main Op that wraps the model itself:
class HMCOp(Op):
    def make_node(self, x, phase):
        inputs = [at.as_tensor_variable(x), at.as_tensor_variable(phase)]
        outputs = [at.dvector()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = sinmodel(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        grad_phase = hmc_grad_op(*inputs)
        output_gradient = output_gradients[0]
        return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
                output_gradient * grad_phase]
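For reference, here are the shapes involved in grad (a quick check I added; I am assuming as_tensor_variable keeps phase 0-dimensional):
at_phase = at.as_tensor_variable(phase)
print(at_phase.ndim)                  # 0: phase enters the graph as a scalar
print(hmc_grad_op(x, at_phase).ndim)  # 1: a length-20 vector of per-point gradients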
I suspect that the last line, output_gradient * grad_phase, is the cause of the error below. Next, I try to run the sampling:
import jax.numpy as jnp
import pymc

def build_model(x, y):
    with pymc.Model() as pmmodel:
        phase = pymc.Uniform('phase', lower=-1.0 * jnp.pi, upper=1.0 * jnp.pi)
        sigma = pymc.Exponential('sigma', lam=1.)
        mu = hmc_op(x, phase)
        #mu = at.sin(x + phase)  # this works
        d = pymc.Normal('y', mu=mu, sigma=sigma, observed=y)
    return pmmodel

model_pymc = build_model(x, y)
with model_pymc:
    idata = pymc.sample(return_inferencedata=True)
Then I got a ValueError:
ValueError: HMCOp.grad returned a term with 1 dimensions, but 0 are required.
Prior to the sampling, I also ran a quick check of grad:
hmc_op = HMCOp()
at_phase = at.as_tensor_variable(phase)
yt = hmc_op(x, at_phase)
yt_grad = at.grad(at.max(yt), wrt=at_phase)
yt_grad.eval()
and it resulted in the same error.
So, I suspect that I got this error because output_gradient * grad_phase is 1-dimensional, not a scalar. This can be checked by using
return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
        output_gradient[1] * at.sum(grad_phase)]  # passes; the dimensions match
instead. In this case, I was able to run HMC-NUTS at least formally (though of course the initialization failed, because the derivative is wrong):
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...
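For completeness, my current (unverified) guess is that the gradient term for the scalar phase has to be reduced to a scalar, perhaps by summing the element-wise product, along these lines:
# unverified guess: contract the vector down to a scalar for the 0-d phase input
return [aesara.gradient.grad_not_implemented(op=self, x_pos=0, x=inputs[0]),
        at.sum(output_gradient * grad_phase)]
but I am not confident this is the right contraction.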
So, I think I need to understand how to specify the gradient for a vector-type output in an Op, but I am not sure how to do it properly. Does anyone have an idea how to resolve this problem?
Thanks in advance.