ODEs, PyMC4 and custom likelihood in JAX

Hi everyone,

I have a log-likelihood function for a large ODE system (10-20 state variables and 60-100 parameters), which I have to evaluate for 1000 individuals per iteration. It is implemented in JAX (vmap, odeint) for speed. Is there a straightforward way to use this custom likelihood (or at least my vmap(odeint(…)) function) in a PyMC4 (or PyMC3) model? In NumPyro, for instance, I can easily add the jax.vmap(odeint(…)) term to the model myself (see, e.g., this example: Example: Predator-Prey Model — NumPyro documentation).
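For concreteness, a batched ODE likelihood of the kind described here might look roughly like the sketch below. The one-parameter decay model, the Gaussian noise model, and every name in it (decay, loglike_one, batched_loglike) are hypothetical stand-ins and not the actual system. The solver tolerances are borrowed from the odeint call that appears later in this thread.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t, k):
    # Toy right-hand side: dy/dt = -k * y (stand-in for the real ODE system).
    return -k * y


def loglike_one(k, y0, ts, obs, sigma):
    # Solve the ODE for one individual and score it with a Gaussian likelihood.
    pred = odeint(decay, y0, ts, k, rtol=1e-6, atol=1e-5, mxstep=1000)
    resid = (obs - pred) / sigma
    return jnp.sum(-0.5 * resid**2 - jnp.log(sigma) - 0.5 * jnp.log(2.0 * jnp.pi))


# Map over individuals: k, y0 and obs are batched; ts and sigma are shared.
batched_loglike = jax.jit(jax.vmap(loglike_one, in_axes=(0, 0, None, 0, None)))

ts = jnp.linspace(0.0, 1.0, 5)
ks = jnp.array([0.5, 1.0, 2.0])  # one rate per individual
y0s = jnp.ones((3, 1))  # one initial state vector per individual
obs = jnp.exp(-ks[:, None] * ts[None, :])[..., None]  # noise-free synthetic data

per_individual = batched_loglike(ks, y0s, ts, obs, 0.1)
total = per_individual.sum()
print(per_individual.shape)  # (3,)
```

The total log likelihood is the sum over individuals, and it is this scalar function of the parameters that would need to be exposed to PyMC.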

Thanks!

You can wrap a JAX function in a Theano/Aesara Op to be used in conjunction with a PyMC model

Hi @yunus,
The thing with JAX Ops is that the entire model needs to be written as JAX-able Ops if you want to sample the PyMC v4 model with NumPyro. That should be possible AFAIK.
Technically, Aesara should be flexible enough to compile a function where some parts (“thunks”) are C-compiled from COps and some are JAX-jitted from Ops that have JAX implementations. But @ricardoV94, please fact-check me here.

For your likelihood, you could implement a custom Op and register JAX implementations for the perform (forward) and grad (backward) methods, so you can do that vmap thing.

@ricardoV94 is there a vmap equivalent for C-compiled Aesara Ops? Or is that not even necessary because it’s already dealt with at the level of the C-compiler?

In any case it should be quite interesting to build your model in PyMC v4 and benchmark sunode/sundials vs. jax odeint!

Thanks a lot for your replies. Sorry I didn’t get around to this yet, it’s on my to-do list. I was thinking I may also try out the Blackjax library. I’ll get back to you here once I’ve tried it out

I am attempting to solve a similar problem…
I currently have a vehicle model with 5 states and 5 parameters (including the data noise, which I sample). The vehicle model can be found here. It is implemented in JAX and, like the OP, I want to implement a custom likelihood function in PyMC v4. Given here is my first attempt at this. However, my chains fail, and this is the error I get:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)% [0/6000 00:00<00:00 Sampling 4 chains, 0 divergences]
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
pymc.parallel_sampling.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
ValueError: Expected 1 dimensions input

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/parallel_sampling.py", line 125, in run
    self._start_loop()
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/parallel_sampling.py", line 178, in _start_loop
    point, stats = self._compute_point()
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/parallel_sampling.py", line 203, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/step_methods/arraystep.py", line 286, in step
    return super().step(point)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/step_methods/arraystep.py", line 208, in step
    step_res = self.astep(q)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/step_methods/hmc/base_hmc.py", line 156, in astep
    start = self.integrator.compute_state(q0, p0)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/step_methods/hmc/integration.py", line 47, in compute_state
    logp, dlogp = self._logp_dlogp_func(q)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/model.py", line 408, in __call__
    cost, *grads = self._aesara_function(*grad_vars)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py", line 977, in __call__
    raise_with_op(
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/utils.py", line 538, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
ValueError: Expected 1 dimensions input
Apply node that caused the error: Subtensor{int64}(LoglikeGrad.0, ScalarConstant{4})
Toposort index: 22
Inputs types: [TensorType(float64, (None,)), Scalar(int64)]
Inputs shapes: [(5,), ()]
Inputs strides: ['No strides', ()]
Inputs values: [DeviceArray([ 1.71861555e-02,  1.10738053e+00, -2.40910899e+00,
              1.57400068e+06,  7.75898922e+03], dtype=float64), 4]
Outputs clients: [[Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + i4 + (i5 * i2))}}(Elemwise{ge,no_inplace}.0, TensorConstant{-1111.1111111111113}, sigmaLat_acc_log___log, TensorConstant{0}, (d__logp/dsigmaVy_log___logprob){1.0}, Subtensor{int64}.0)]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1214, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
"""

The above exception was the direct cause of the following exception:

ValueError: Expected 1 dimensions input
Apply node that caused the error: Subtensor{int64}(LoglikeGrad.0, ScalarConstant{4})
Toposort index: 22
Inputs types: [TensorType(float64, (None,)), Scalar(int64)]
Inputs shapes: [(5,), ()]
Inputs strides: ['No strides', ()]
Inputs values: [DeviceArray([ 1.71861555e-02,  1.10738053e+00, -2.40910899e+00,
              1.57400068e+06,  7.75898922e+03], dtype=float64), 4]
Outputs clients: [[Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + i4 + (i5 * i2))}}(Elemwise{ge,no_inplace}.0, TensorConstant{-1111.1111111111113}, sigmaLat_acc_log___log, TensorConstant{0}, (d__logp/dsigmaVy_log___logprob){1.0}, Subtensor{int64}.0)]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1059, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/gradient.py", line 1214, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/hussainmustafa/Desktop/research/tutorials/vd_bi_aes.py", line 168, in <module>
    main()
  File "/Users/hussainmustafa/Desktop/research/tutorials/vd_bi_aes.py", line 156, in main
    idata = pm.sample(ndraws ,tune=nburn,discard_tuned_samples=True,return_inferencedata=True,target_accept = 0.9, cores=4)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling.py", line 543, in sample
    mtrace = _mp_sample(**sample_args, **parallel_args)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling.py", line 1470, in _mp_sample
    for draw in sampler:
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/parallel_sampling.py", line 460, in __iter__
    draw = ProcessAdapter.recv_draw(self._active)
  File "/Users/hussainmustafa/opt/anaconda3/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/parallel_sampling.py", line 349, in recv_draw
    raise error from old_error
RuntimeError: Chain 2 failed.

I don’t understand Aesara well enough to see what is going on here, and I am lost as to how to fix it. I think I have to create a JAX Op to wrap the JAX function, as shown here, but I am not sure…

@huzaifg looks like you have a shape problem. This is the important part of the traceback:

ValueError: Expected 1 dimensions input
Apply node that caused the error: Subtensor{int64}(LoglikeGrad.0, ScalarConstant{4})
Toposort index: 22
Inputs types: [TensorType(float64, (None,)), Scalar(int64)]
Inputs shapes: [(5,), ()]

Try to assert the shapes of your variables as you build up the model, e.g. assert tuple(x.eval().shape) == (1, 2, 3).
I don’t know how to do that in the JAX code, though. The PyMC model looks fine; I’d guess the bug is in the JAX code…
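On the JAX side, one possible way to do such checks is sketched below. Array shapes are static in JAX, so plain Python asserts inside a jitted function run at trace time, and jax.eval_shape reports output shapes without doing any numerical work. The loglike here is a simplified, hypothetical stand-in (no ODE solve), and the expected shapes (5,) and (N, 2) are assumptions based on the 5 parameters and 2 observed signals mentioned in this thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def loglike(theta, targets):
    # Shapes are known at trace time, so these asserts work under jit.
    assert theta.shape == (5,), f"unexpected theta shape {theta.shape}"
    assert targets.ndim == 2 and targets.shape[1] == 2, targets.shape
    # The last entries of theta are the noise scales, one per observed signal.
    sigmas = theta[-targets.shape[1]:]
    return -jnp.sum(targets**2 / (2.0 * sigmas**2))


theta = jnp.array([-88000.0, -88000.0, 1000.0, 0.006, 0.04])
targets = jnp.ones((10, 2))

# Abstract evaluation: only shapes and dtypes are computed.
out = jax.eval_shape(loglike, theta, targets)
grad_out = jax.eval_shape(jax.grad(loglike), theta, targets)
print(out.shape, grad_out.shape)  # () (5,)
```

Comparing these shapes with what the wrapping Op declares for its outputs is one way to narrow down where a mismatch comes from.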
good luck

The JAX code works on its own. For example, a script like this runs perfectly well and I get the gradients I expect:

from vd_bi_mod_aes import vehicle_bi
import jax.numpy as jnp
import scipy.io as sio
import sys
import jax
from jax import grad,vmap,jit
from jax.test_util import check_grads
from jax.experimental.ode import odeint
from jax.random import multivariate_normal
import time
import matplotlib.pyplot as mpl
import timeit
import numpy as np
import aesara
import aesara.tensor as at

#params
Cf = -88000.
Cr = -88000.
Iz = 1000.
sigmaVy = 0.006
sigmaYr = 0.04
theta = [Cf,Cr,Iz,sigmaVy,sigmaYr]


#Initial state
wf = 50./(3.6 * 0.285) #Angular velocity of front wheel
wr = 50./(3.6 * 0.285) #Angular velocity of rear wheel
Vx = 50./3.6 #Longitudinal velocity
Vy = 0. #Lateral velocity
yr = 0. #Yaw rate
state = jnp.array([Vy,Vx,yr,wf,wr],float)




#Target data (dataFileName and add_noise are defined elsewhere and not shown in this post)
vbdata = sio.loadmat(dataFileName)
lat_vel_o = vbdata['lat_vel'].reshape(-1,)
yaw_rate_o = vbdata['yaw_rate'].reshape(-1,)
lat_vel_o = add_noise(lat_vel_o)
yaw_rate_o = add_noise(yaw_rate_o)

target = jnp.stack([lat_vel_o,yaw_rate_o],axis=-1)


time_o = jnp.asarray(vbdata['tDash'].reshape(-1,),float)

@jit
def loglike(theta,state,time_o,targets):
	mod = odeint(vehicle_bi, state , time_o, theta,rtol=1e-6, atol=1e-5, mxstep=1000)
	sigmas = jnp.array(theta[-(targets.shape[1]):])

	return -jnp.sum(jnp.sum((mod[:,[0,2]] - targets)**2/(2.*sigmas**2))/jnp.linalg.norm(targets,axis = 0))


@jit
def gradLogLike(theta,state,time,target):
	return grad(loglike)(jnp.array(theta,float),state,time,target)



grads = gradLogLike(theta,state,time_o,target)
print(grads)

This outputs

[-8.7592164e-03  4.6842605e-01 -9.8126686e-01  1.3541200e+05
  3.0881677e+03]

Which is what I expect. However, when I combine it into an Aesara Op, I get those shape errors.
Another thing to note: theta in the test script above is a list, whereas during sampling it has to be an Aesara tensor, such as

Cf = pm.Uniform('Cf',lower = -150000, upper = -50000,initval = -80000) # front axle cornering stiffness (N/rad)
Cr = pm.Uniform('Cr',lower = -150000, upper = -50000,initval = -80000) # rear axle cornering stiffness (N/rad)
Iz = pm.Uniform('Iz',lower = 500, upper = 3000,initval = 2450) # yaw moment of inertia (kg.m^2)
sigmaVy = pm.HalfNormal("sigmaVy",sigma = 0.006,initval=0.005) # Noise for lateral velocity
sigmaYr = pm.HalfNormal("sigmaLat_acc",sigma = 0.03,initval=0.03) #Noise for yaw rate
#Theta here is an Aesara tensor
theta = at.as_tensor_variable([Cf,Cr,Iz,sigmaVy,sigmaYr])

I think the Aesara Op still converts theta into a list before passing it to loglike and grad_loglike, but I am unsure whether this is the source of the problem.

In summary, the JAX code works when tested separately but breaks when wrapped in an Aesara Op.

I didn’t have time to check your script, but in case it helps, we have a WIP guide on how to integrate a JAX function into an Aesara Op / PyMC model here: Add guide on how to wrap a JAX function in a Aesara Op by ricardoV94 · Pull Request #302 · pymc-devs/pymc-examples · GitHub

About the lists: Aesara will convert the input to a NumPy array before passing it to JAX, not to a list.

Thanks for the link, it helps a lot!

I have got my sampler up and running but would like to only take gradients with respect to my first input (theta).

My code is currently set up the following way

@jit
def loglike(theta,state,time,targets):
	#Evaluate the model
	mod = odeint(vehicle_bi, state , time,theta,rtol=1e-6, atol=1e-5, mxstep=1000)
	#Get the sigmas from theta
	sigmas = jnp.array(theta[-(targets.shape[1]):],float)
	#Evaluate the negative scaled sum of squared errors, which is used as the log likelihood
	return -jnp.sum(jnp.sum((mod[:,[0,2]] - targets)**2/(2.*sigmas**2))/jnp.linalg.norm(targets,axis = 0))



#Gradient of loglikelihood function
grad_loglike = jit(grad(loglike,argnums=0))



from aesara.graph.basic import Apply  # needed for the Apply(...) calls below

# Define a custom Aesara Op
class LogLike(at.Op):

	def make_node(self, *inputs):
		# Convert our inputs to symbolic variables
		inputs = [at.as_tensor_variable(inp) for inp in inputs]
		# Define the type of the output returned by the wrapped JAX function
		outputs = [at.dscalar()]
		return Apply(self, inputs, outputs)
	def perform(self, node, inputs, outputs):
		result = loglike(*inputs)
		outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

	def grad(self, inputs, output_gradients):
		gradients = logprob_grad_op(*inputs)
		return [output_gradients[0] * gradient for gradient in gradients]
# Similarly, a wrapper class for the loglike gradient
class LoglikeGrad(at.Op):

	def make_node(self, *inputs):
		inputs = [at.as_tensor_variable(inp) for inp in inputs]
		outputs = [inp.type() for inp in inputs]
		return Apply(self, inputs, outputs)

	def perform(self, node, inputs, outputs):

		results = grad_loglike(*inputs)
		for i, result in enumerate(results):
			outputs[i][0] = np.asarray(result, dtype=node.outputs[i].dtype)
# Initialize our `Op`s
logp_op = LogLike()
logprob_grad_op = LoglikeGrad()

#Evaluate
logp_op(theta,state,time_o,target).eval()
logprob_grad_op(theta,state,time_o,target)[1].eval()

I get an IndexError from the gradient Op:

  File "/Users/hussainmustafa/Desktop/research/tutorials/vd_bi_aes.py", line 96, in perform
    outputs[i][0] = np.asarray(result, dtype=node.outputs[i].dtype)
IndexError: list index out of range

I think this happens because I only define the gradient with respect to theta (argnums=0), since that is the only gradient I am interested in, whereas the Op expects a gradient with respect to all the inputs. I see the suggestion in the PR to use aesara.gradient.grad_not_implemented in case I have gradients that I don’t want computed, but I do not know where to use it.

The code works if I replace grad_loglike = jit(grad(loglike,argnums=0)) with grad_loglike = jit(grad(loglike,argnums=list(range(4)))), but I don’t want to compute gradients with respect to all the inputs.

You can start by making your grad Op return a single output (the one you care about in your main Op).

Then in the grad method of your main Op, you’ll want to multiply it with the right output gradient and mark the others as not_implemented. There’s an example here in the source code of Aesara: aesara/math.py at main · aesara-devs/aesara · GitHub
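To illustrate why the grad Op needs a single output, here is a sketch in plain JAX (no Aesara involved). With an integer argnums, jax.grad returns one array shaped like that argument, while a tuple of argnums returns a tuple with one gradient per argument. Looping over the single array with enumerate therefore walks over its 5 entries, and writing them into 4 output slots fails at index 4. The loglike below is a toy stand-in with the same argument layout, and the outputs list only mimics how an Op's perform method receives one storage cell per declared output.

```python
import jax
import jax.numpy as jnp


def loglike(theta, state, time, targets):
    # Toy stand-in: 5 parameters, the last two acting as noise scales.
    pred = state[0] * jnp.exp(-theta[0] * time) + theta[1] + theta[2]
    sigma = theta[3] + theta[4]
    return -jnp.sum((pred - targets) ** 2) / (2.0 * sigma**2)


theta = jnp.array([0.5, 0.1, -0.1, 0.05, 0.05])
state = jnp.array([1.0])
time = jnp.linspace(0.0, 1.0, 4)
targets = jnp.exp(-0.4 * time)

# Integer argnums: a single array with the shape of theta.
g_single = jax.grad(loglike, argnums=0)(theta, state, time, targets)
# Tuple argnums: a tuple with one gradient per listed argument.
g_all = jax.grad(loglike, argnums=(0, 1, 2, 3))(theta, state, time, targets)
print(g_single.shape, len(g_all))  # (5,) 4

# Mimic the loop in LoglikeGrad.perform with one storage cell per Op output.
outputs = [[None] for _ in range(4)]
try:
    for i, result in enumerate(g_single):  # iterates over the 5 entries of theta's gradient
        outputs[i][0] = result
    failed = False
except IndexError:
    failed = True
print(failed)  # True
```

With a grad Op that declares a single output and stores the whole theta gradient in it, no loop over the result is needed, and the main Op's grad method only has to return that term for theta while marking the remaining inputs as not implemented.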
