Theano Op using JAX for lightning-fast ODE inference

I’ve written a Theano Op that uses JAX to solve and autodifferentiate a system of ODEs. It allows parameter estimation via ADVI that’s ~8x faster than SUNODE and ~120x faster than PyMC3’s native DifferentialEquation module. Thanks to easy vectorization, it will probably do even better on a GPU. If there’s interest, I can write this up into a notebook with a more step-by-step explanation.

I thought I’d share what I came up with below, in case anyone else needs it. I worked off of this guide for the SUNODE implementation, these pages for the DifferentialEquation implementation, and this guide for the custom Op. The one issue I still have is that the Op doesn’t seem to play nicely with NUTS; I’ll elaborate on that at the bottom of this post. Let me know if anyone has suggestions!


Imports
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano
import theano.tensor as tt
import seaborn as sns
theano.config.compute_test_value = 'ignore'


from scipy.integrate import odeint
from sunode.wrappers.as_theano import solve_ivp
SEED=2021

import jax
from jax.experimental.ode import odeint as jodeint
import jax.numpy as jnp


# Use 64-bit floats in JAX to match Theano's default float64 (set this before creating any JAX arrays)
jax.config.update("jax_enable_x64", True)

Model system

For illustration we’ll use a simple ODE representing the enzymatic conversion of one chemical species into another. Note that the right-hand side follows the odeint argument order, rhs(y, t, p). One caveat of using JAX is that this RHS has to be compatible with JAX. For simple definitions like the one below, we don’t have to do anything differently than we usually would. For slightly more complex definitions, you’ll just have to use JAX’s NumPy wrapper (jax.numpy) instead of NumPy, and in some instances (for example, Python control flow that depends on the state or parameters) you’ll have to rework the internals a bit.

# Simple enzymatic reaction with two species
def rhs(y, t, p):
    S, P = y[0], y[1]
    vmax, K_S = p[0], p[1]
    dPdt = vmax * (S / K_S + S)
    dSdt = -dPdt
    return [
        dSdt,
        dPdt,
    ]
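As an example of the kind of rework mentioned above: a Python `if` on the state can’t be traced by JAX, so a branch like that has to be rewritten with `jnp.where`. The threshold `S_MIN` and the piecewise rate below are hypothetical, invented only to illustrate the pattern; they are not part of the model in this post.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

S_MIN = 1.0  # hypothetical threshold below which the enzyme is inactive


def rhs_python_if(y, t, p):
    # Works with scipy.integrate.odeint, but the Python `if` needs a concrete
    # boolean, which a traced JAX value cannot provide under jax.jit / JAX's odeint
    S = y[0]
    vmax, K_S = p[0], p[1]
    dPdt = vmax * (S / K_S + S) if S > S_MIN else 0.0
    return [-dPdt, dPdt]


def rhs_jax(y, t, p):
    # Same dynamics, with the branch expressed as jnp.where so it can be traced
    S = y[0]
    vmax, K_S = p[0], p[1]
    dPdt = jnp.where(S > S_MIN, vmax * (S / K_S + S), 0.0)
    return jnp.stack([-dPdt, dPdt])


y = jnp.array([10.0, 2.0])
p = jnp.array([0.5, 2.0])

rate = np.asarray(jax.jit(rhs_jax)(y, 0.0, p))  # [-7.5, 7.5]
eager = rhs_python_if(np.array([10.0, 2.0]), 0.0, np.array([0.5, 2.0]))  # fine eagerly

try:
    jax.jit(rhs_python_if)(y, 0.0, p)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False  # the Python `if` fails once the inputs are traced
```

One thing to keep in mind with this pattern: `jnp.where` evaluates both branches, so if the unused branch can produce `inf` or `nan`, gradients can still pick them up.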
Toy data

rng=np.random.default_rng(SEED)

# Times for observation
times = np.arange(0, 10, 0.5)
S_idx = np.arange(5, len(times))
P_idx = np.arange(12)
S_t = times[S_idx]
P_t = times[P_idx]

y0_true = (10., 2.)
theta_true = vmax, K_S = (0.5, 2.)
sigma = 1

obs = odeint(rhs, t=times, y0=y0_true, args=(theta_true,))
S_obs = rng.normal(obs[S_idx, 0], sigma)
P_obs = rng.normal(obs[P_idx, 1], sigma)

fig, ax = plt.subplots(dpi=120)
plt.plot(S_t, S_obs, label="S", linestyle="none", marker="o", color="red")
plt.plot(P_t, P_obs, label="P", linestyle="none", marker="o", color="blue")
plt.plot(times, obs.T[0], label="S", color="red")
plt.plot(times, obs.T[1], label="P", color="blue")
plt.legend()
plt.xlabel("Time (Seconds)")
plt.ylabel(r"$y(t)$")
plt.show()

The Op

class ODEop(tt.Op):
        
    def __init__(self, solver, vectorized=True):
        # With vectorized=True, JAX's vmap maps the solver over the rows of a 2D parameter array;
        # either way the solver is jit-compiled
        self._solver = jax.jit(jax.vmap(solver)) if vectorized else jax.jit(solver)
        # JAX's autodifferentiation allows automatic construction of the vector-Jacobian product
        self._vjp = jax.jit(lambda params, grad: jax.vjp(self._solver, params)[1](grad)[0])
        # We need a separate Op to allow Theano to calculate the gradient via JAX's vjp
        self._grad_op = ODEGradop(self._vjp)
        
    def make_node(self, p):
        # Tells Theano what to expect for the inputs and outputs. The output has the
        # same ndim and dtype as the input (1D params -> 1D solution, 2D -> 2D),
        # so the input's type is reused for the output
        p = theano.tensor.as_tensor_variable(p)
        node = theano.tensor.Apply(self, [p], [p.type()])
        return node

    def perform(self, node, inputs, output):
        # Calls the jitted JAX solver and converts the result back to NumPy
        params = inputs[0]
        output[0][0] = np.array(self._solver(params))  # numerical solution of the ODE states

    def grad(self, inputs, output_grads):
        # Theano passes in the gradient of the cost w.r.t. this Op's output;
        # the vjp maps it to the gradient w.r.t. the parameters
        params = inputs[0]
        grads = output_grads[0]
        return [self._grad_op(params, grads)]
    
class ODEGradop(tt.Op):
    
    def __init__(self, vjp):
        self._vjp = vjp
        
    def make_node(self, p, g):
        p = theano.tensor.as_tensor_variable(p)
        g = theano.tensor.as_tensor_variable(g)
        # The vjp returns one entry per parameter, so the output has the parameters' type
        node = theano.tensor.Apply(self, [p, g], [p.type()])
        return node

    def perform(self, node, inputs_storage, output_storage):
        params = inputs_storage[0]
        grads = inputs_storage[1]
        out = output_storage[0]
        # Get the numerical vector-Jacobian product
        out[0] = np.array(self._vjp(params,grads))
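To see what `ODEGradop` actually computes: `jax.vjp` returns a function that maps a cotangent `g` (the gradient Theano passes in for the Op’s output) to `g^T J`, where `J` is the Jacobian of the flattened solution with respect to the parameters. The RHS used in this post happens to be linear in S, so the solution has a closed form, S(t) = s0 * exp(-vmax * (1/K_S + 1) * t) and P(t) = p0 + s0 - S(t), which gives an exact Jacobian to check against. This is a standalone sketch that assumes `jax.experimental.ode` is available; it re-declares the model with array-valued states and uses tighter solver tolerances than the post’s defaults.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint as jodeint

jax.config.update("jax_enable_x64", True)

times = np.arange(0, 10, 0.5)
p0 = 2.0  # known initial condition for P, as in the post


def rhs(y, t, p):
    S = y[0]
    vmax, K_S = p[0], p[1]
    dPdt = vmax * (S / K_S + S)
    return jnp.stack([-dPdt, dPdt])


def get_sol(params):
    # params = [s0, vmax, K_S] -> flattened [S(t_0..t_n), P(t_0..t_n)]
    y0 = jnp.stack([params[0], p0])
    sol = jodeint(rhs, y0, times, params[1:], rtol=1e-10, atol=1e-10)  # (n_t, 2)
    return sol.T.reshape(-1)


def closed_form(params):
    # Exact solution of this particular (linear in S) system
    s0, vmax, K_S = params[0], params[1], params[2]
    S = s0 * jnp.exp(-vmax * (1.0 / K_S + 1.0) * times)
    return jnp.concatenate([S, p0 + s0 - S])


# Same construction as ODEop._vjp: pull the cotangent g back through the solver
vjp = jax.jit(lambda params, g: jax.vjp(get_sol, params)[1](g)[0])

params = jnp.array([10.0, 0.5, 2.0])
g = jnp.asarray(np.random.default_rng(0).normal(size=2 * len(times)))

ode_grad = vjp(params, g)                         # shape (3,)
exact_grad = g @ jax.jacrev(closed_form)(params)  # g^T J, shape (3,)
```

`ODEGradop.perform` evaluates exactly this kind of function numerically; the post’s `theano.gradient.verify_grad` call checks the same thing from the Theano side, against finite differences.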

Here’s a helper class that wraps the solver and the Op, and makes it easy to switch between the vectorized and non-vectorized versions or to change the time grid.

Helper class
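class ODE:
    # `solver` must have the signature solver(times, params) -> 1D array, so that
    # the time grid can be changed later through the `times` property
    def __init__(self, solver, times, vectorized=True):
        self._times = times
        self._vectorized = vectorized
        self.__solver = solver
        self.build_solvers()
        self.build_op()
    
    def solve(self, params):
        # A single parameter set goes to the plain solver, a 2D batch to the vmapped one
        params = np.atleast_2d(params)
        return self._solver(params[0]) if params.shape[0] == 1 else self._vsolver(params)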
    
    def build_solvers(self):
        self._solver = jax.jit(lambda params: self.__solver(self._times,params))
        self._vsolver = jax.jit(jax.vmap(self._solver))
    
    def build_op(self):
        self.Op = ODEop(self._solver, vectorized=self.vectorized)
        
    def sample_posterior_fits(self, fit, param_names, n=1000):
        params_array = np.array([fit[p][:n] for p in param_names]).T
        if self.vectorized:
            solutions = jax.vmap(self._vsolver)(params_array)
            return np.array(solutions).reshape([params_array.shape[0], n, 2, -1])
        else:
            solutions = self._vsolver(params_array)
            return np.array(solutions).reshape([n, 2, -1])
        
    @property
    def vectorized(self):
        return self._vectorized
    
    @vectorized.setter
    def vectorized(self, is_vec: bool):
        self._vectorized = is_vec
        self.build_op()
        
    @property
    def times(self):
        return self._times
    
    @times.setter
    def times(self, _times):
        self._times = _times
        self.build_solvers()
        self.build_op()
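The shape bookkeeping in `sample_posterior_fits` is easy to misread. This NumPy-only sketch uses hypothetical dicts of random draws as stand-ins for a real PyMC3 trace (which is indexed the same way, `trace[name]`) to show what `params_array` looks like in each mode.

```python
import numpy as np

rng = np.random.default_rng(0)
names = ["s_0", "vmax", "K_S"]
n_draws, n_r, n = 500, 4, 100

# Hypothetical stand-ins for a trace: name -> array of posterior draws
scalar_fit = {p: rng.lognormal(size=n_draws) for p in names}         # one system
vector_fit = {p: rng.lognormal(size=(n_draws, n_r)) for p in names}  # n_r systems

# Same expression as in sample_posterior_fits
scalar_params = np.array([scalar_fit[p][:n] for p in names]).T
vector_params = np.array([vector_fit[p][:n] for p in names]).T

print(scalar_params.shape)  # (100, 3): one row [s_0, vmax, K_S] per draw
print(vector_params.shape)  # (4, 100, 3): .T reverses all three axes
```

So in vectorized mode, `jax.vmap(self._vsolver)` maps over systems on the outer axis and over draws on the inner one, giving an `(n_r, n, 2 * n_t)` array before the final reshape.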

Model definition with DifferentialEquation or SUNODE

Native PyMC3 implementation
with pm.Model() as p_model:
    sigma = pm.Exponential("sigma", 1)
    vmax = pm.Lognormal("vmax", 0, 1)
    K_S = pm.Lognormal("K_S", 0, 1)
    s0 = pm.Lognormal("s_0", mu=np.log(10), sd=1)
    
    kin_model = pm.ode.DifferentialEquation(
        func=rhs,
        times=times,
        n_states=len(y0_true),
        n_theta=len(theta_true)
    )

    solution = kin_model(
        y0=[s0, y0_true[1]],
        theta=[vmax, K_S],
        return_sens=False
    )

    S_hat = solution.T[0][S_idx]
    P_hat = solution.T[1][P_idx]

    S_lik = pm.Normal("S_lik", mu=S_hat, sd=sigma, observed=S_obs)
    P_lik = pm.Normal("P_lik", mu=P_hat, sd=sigma, observed=P_obs)
SUNODE implementation
def sunode_rhs(t, y, p):
    S, P = y.S, y.P
    vmax, K_S = p.vmax, p.K_S
    dPdt = vmax * (S / K_S + S)
    dSdt = -dPdt
    return {
        'S': dSdt,
        'P': dPdt,
    }

with pm.Model() as s_model:
    sigma = pm.Exponential("sigma", 1)
    vmax = pm.Lognormal("vmax", 0, 1)
    K_S = pm.Lognormal("K_S", 0, 1)
    s0 = pm.Lognormal("s_0", mu=np.log(10), sd=1)
    
    y0 = {
        'S': (s0,()),
        'P': np.float64(y0_true[1]),  # fixed (known) initial condition, not estimated
    }
    
    params = {
        'vmax': (vmax,()),
        'K_S': (K_S,()),
        '_dummy': (np.array(1.), ()),
    }
    
    solution, *_ = solve_ivp(
        y0=y0,
        params=params,
        rhs=sunode_rhs,
        # The time points where we want to access the solution
        tvals=times,
        t0=times[0],
    )
    
    
    S_hat = solution['S'][S_idx]
    P_hat = solution['P'][P_idx]

    S_lik = pm.Normal("S_lik", mu=S_hat, sd=sigma, observed=S_obs)
    P_lik = pm.Normal("P_lik", mu=P_hat, sd=sigma, observed=P_obs)

JAX Op Implementation

The Op is initialized with a function that takes a list of parameters as input and returns the ODE solution flattened into a 1D array (there’s probably a way to do this with multi-dimensional outputs). Note that this allows mixing and matching known and unknown parameters: here the initial value of P is fixed inside the function, while s0, vmax and K_S are estimated.

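def get_soln(times, params):
    # Solver with an explicit time grid; this is the signature the ODE helper class expects
    s0, vmax, K_S = params
    p0 = y0_true[1]  # known initial condition for P
    solution = jodeint(rhs, (s0, p0), times, (vmax, K_S))
    return jnp.hstack(solution)  # [S(t_0), ..., S(t_n), P(t_0), ..., P(t_n)]

def get_sol(params):
    # Same solver with the observation times baked in, for building the Op directly
    return get_soln(times, params)

# One parameter vector per call, so the Op is built without vmap
reaction_Op = ODEop(jax.jit(get_sol), vectorized=False)

# Or, equivalently, with the helper class
reaction = ODE(get_soln, times, vectorized=False)
reaction_Op = reaction.Op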

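# Just to verify the gradient is implemented correctly
# Not necessary to run every time
theano.config.compute_test_value = 'ignore'
rng = np.random.RandomState(SEED)
# Check the gradient at the true parameter values [s_0, vmax, K_S]
test_point = np.array([y0_true[0], *theta_true])
theano.gradient.verify_grad(reaction.Op, [test_point], rng=rng)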

with pm.Model() as j_model:
    sigma = pm.Exponential("sigma", 1)
    vmax = pm.Lognormal("vmax", 0, 1)
    K_S = pm.Lognormal("K_S", 0, 1)
    s0 = pm.Lognormal("s_0", mu=np.log(10), sd=1)
    #p0 = pm.Normal("blue_0", mu=10, sd=2)
       
    params = tt.stack([s0,vmax,K_S])
    solution = reaction.Op(params).reshape([2,-1])
    
    S_hat = solution[0,:][S_idx]
    P_hat = solution[1,:][P_idx]

    S_lik = pm.Normal("S_lik", mu=S_hat, sd=sigma, observed=S_obs)
    P_lik = pm.Normal("P_lik", mu=P_hat, sd=sigma, observed=P_obs)
ADVI and plotting
with j_model:
    j_approx = pm.fit(
        10000,
        method='advi',
        callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)],
        random_seed=SEED,
    )
    
    j_samples = j_approx.sample(10000)

fits = reaction.sample_posterior_fits(j_samples, param_names=['s_0', 'vmax', 'K_S'])
plot_posterior_fits(fits)

plt.plot(S_t, S_obs, label="S", linestyle="none", marker="o", mec='white', mfc="red")
plt.plot(P_t, P_obs, label="P", linestyle="none", marker="o", mec='white', mfc="blue")
plt.plot(times, obs.T[0], 'k--', lw=2)
plt.plot(times, obs.T[1], 'k--', lw=2)
plt.legend()
plt.xlabel("Time (Seconds)")
plt.ylabel(r"$y(t)$");

[Figure: JAX Op ODE fit posteriors]
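The ADVI snippet calls `plot_posterior_fits`, which isn’t included in the post. A hypothetical version (an assumption, not the original helper) that draws the posterior median and a 90% band for each species might look like this; it takes the time grid explicitly, so the call would become `plot_posterior_fits(fits, times)`.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_posterior_fits(fits, t, ax=None, colors=("red", "blue"), labels=("S", "P")):
    """Hypothetical helper: posterior median and 5-95% band of ODE trajectories.

    fits: array of shape (n_draws, n_species, n_t), e.g. the output of
    ODE.sample_posterior_fits with vectorized=False.
    """
    if ax is None:
        _, ax = plt.subplots(dpi=120)
    # Percentiles over the draws axis; each result has shape (n_species, n_t)
    lo, med, hi = np.percentile(fits, [5, 50, 95], axis=0)
    for i, (color, label) in enumerate(zip(colors, labels)):
        ax.fill_between(t, lo[i], hi[i], color=color, alpha=0.3, lw=0)
        ax.plot(t, med[i], color=color, label=f"{label} (posterior median)")
    return ax
```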

ADVI Benchmarking

Time for 10,000 ADVI iterations:

  • DifferentialEquation: 370 seconds
  • SUNODE: 23 seconds
  • JAX Op: 3 seconds
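The speedups quoted at the top (~8x vs. SUNODE, ~120x vs. DifferentialEquation) are just ratios of these timings:

```python
# Timings for 10,000 ADVI iterations, as reported in the benchmark above (seconds)
timings_s = {"DifferentialEquation": 370, "SUNODE": 23, "JAX Op": 3}

speedup = {name: t / timings_s["JAX Op"] for name, t in timings_s.items()}
for name, s in speedup.items():
    print(f"{name}: takes {s:.1f}x as long as the JAX Op")
# DifferentialEquation: takes 123.3x as long as the JAX Op
# SUNODE: takes 7.7x as long as the JAX Op
# JAX Op: takes 1.0x as long as the JAX Op
```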

Easy vectorization

What if we want to estimate multiple systems simultaneously? Easy!

Replicate data generation
rng=np.random.default_rng(SEED)

# Times for observation
repl_times = np.arange(0, 10, 0.01)
n_t = len(repl_times)

y0_true = (10., 2.)
theta_true = vmax, K_S = (0.5, 2.)
sigma = 0.5
n_r = 50

repl_y_obs = odeint(rhs, t=repl_times, y0=y0_true, args=(theta_true,))
repl_S_obs = rng.normal(repl_y_obs[:, 0], sigma, size=(n_r, n_t))
repl_P_obs = rng.normal(repl_y_obs[:, 1], sigma, size=(n_r, n_t))
reaction.vectorized = True
reaction.times = repl_times 
# note that if you wanted to build the Op directly, you'd have to first define
# `get_sol` in terms of `repl_times`, rather than `times`, then call 
# `ODEop(get_sol, vectorized=True)`

with pm.Model() as vrj_model:
    n_r = repl_S_obs.shape[0]
    sigma = pm.Exponential("sigma", 1)
    vmax = pm.Lognormal("vmax", 0, 1, shape=n_r)
    K_S = pm.Lognormal("K_S", 0, 1, shape=n_r)
    s0 = pm.Lognormal("s_0", mu=np.log(10), sd=1, shape=n_r)
       
    params = tt.stack([s0,vmax,K_S],axis=1)
    solution = reaction.Op(params).reshape([n_r,2,n_t])
    
    S_hat = solution[:,0,:]
    P_hat = solution[:,1,:]

    S_lik = pm.Normal("S_lik", mu=S_hat, sd=sigma, observed=repl_S_obs)
    P_lik = pm.Normal("P_lik", mu=P_hat, sd=sigma, observed=repl_P_obs)
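Under the hood, `vectorized=True` only wraps the per-system solver in `jax.vmap`: a `(n_r, 3)` parameter matrix becomes an `(n_r, 2 * n_t)` output, which the model reshapes to `(n_r, 2, n_t)`. This standalone sketch (assuming `jax.experimental.ode` is available, with a simplified `get_sol` that uses array-valued states and fixes the initial value of P at 2) checks that the batched solve matches a plain Python loop over systems.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint as jodeint

jax.config.update("jax_enable_x64", True)

times = np.arange(0, 10, 0.5)
p0 = 2.0  # known initial condition for P


def rhs(y, t, p):
    S = y[0]
    vmax, K_S = p[0], p[1]
    dPdt = vmax * (S / K_S + S)
    return jnp.stack([-dPdt, dPdt])


def get_sol(params):
    # params = [s0, vmax, K_S] -> flattened [S(t_0..t_n), P(t_0..t_n)]
    y0 = jnp.stack([params[0], p0])
    sol = jodeint(rhs, y0, times, params[1:])  # (n_t, 2)
    return sol.T.reshape(-1)


# What ODEop builds internally when vectorized=True
batched = jax.jit(jax.vmap(get_sol))

# One row of [s0, vmax, K_S] per system
params = np.array([[10.0, 0.5, 2.0],
                   [8.0, 0.3, 1.5],
                   [12.0, 0.7, 2.5]])

out = batched(params)  # shape (3, 2 * len(times))
looped = np.stack([np.asarray(get_sol(row)) for row in params])
S_hat = np.asarray(out).reshape(3, 2, -1)[:, 0, :]  # same reshape as in the model
```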

That’s 50 ODE systems with a total of 100,000 data points (50 systems × 2 species × 1,000 time points), and 10,000 ADVI iterations took just over 8 minutes on my laptop (naively using the non-vectorized Op in a for loop took about 10 minutes). Since JAX compiles through XLA for whatever hardware backend is available, I’d imagine you’ll see an even bigger performance boost on a GPU (and maybe a TPU!).

I didn’t really see a performance improvement with a DIY minibatch implementation, likely because most of the cost is still in solving the ODEs rather than evaluating the likelihood (the full solution is still computed at every time point before the batch is selected). Here it is in case anyone’s curious:

DIY Minibatch
with pm.Model() as mvrj_model:
    n_r = repl_S_obs.shape[0]
    sigma = pm.Exponential("sigma", 1)
    vmax = pm.Lognormal("vmax", 0, 1, shape=n_r)
    K_S = pm.Lognormal("K_S", 0, 1, shape=n_r)
    s0 = pm.Lognormal("s_0", mu=np.log(10), sd=1, shape=n_r)
       
    params = tt.stack([s0,vmax,K_S],axis=1)
    solution = reaction.Op(params).reshape([n_r,2,n_t])
    
    # DIY Minibatch
    batch_size = theano.shared(20)
    ridx = pm.tt_rng().uniform(size=(batch_size,), low=0, high=n_t-1e-10).astype('int64').sort()
    S_obs_shared = theano.shared(repl_S_obs)
    P_obs_shared = theano.shared(repl_P_obs)
    
    S_hat = solution[:,0,ridx]
    P_hat = solution[:,1,ridx]

    mini_S = S_obs_shared[:,ridx]
    mini_P = P_obs_shared[:,ridx]
    
    # total_size describes the full observed data, so PyMC3 rescales each
    # minibatch log-likelihood by (n_r / n_r) * (n_t / batch_size)
    S_lik = pm.Normal("S_lik", mu=S_hat, sd=sigma, observed=mini_S, total_size=[n_r, n_t])
    P_lik = pm.Normal("P_lik", mu=P_hat, sd=sigma, observed=mini_P, total_size=[n_r, n_t])
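For a minibatch likelihood to stand in for the full one, each batch’s log-likelihood has to be scaled up by the ratio of full to batch size along the subsampled axis; in PyMC3 that scaling comes from the `total_size` argument, which should describe the full observed data rather than the batch. This NumPy sketch uses random stand-in log-likelihood values (not the model’s) and draws time indices with `rng.integers`, which is equivalent to the post’s floor-of-uniform trick. Averaged over random batches, the rescaled sum recovers the full-data sum, while the unscaled one is off by a factor of `batch_size / n_t`.

```python
import numpy as np

rng = np.random.default_rng(2021)
n_r, n_t, batch_size = 50, 1000, 20

# Stand-in per-point Gaussian log-likelihood terms for fixed parameters
resid = rng.normal(size=(n_r, n_t))
loglik = -0.5 * np.log(2 * np.pi) - 0.5 * resid**2
full = loglik.sum()


def minibatch_estimate(rng):
    # Time indices drawn uniformly with replacement, like `ridx` in the model
    ridx = np.sort(rng.integers(0, n_t, size=batch_size))
    # Rescale by n_t / batch_size so the estimate targets the full-data sum
    return (n_t / batch_size) * loglik[:, ridx].sum()


estimates = np.array([minibatch_estimate(rng) for _ in range(5000)])
unscaled_mean = estimates.mean() * batch_size / n_t
```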

NUTS issue

As I mentioned above, for some reason NUTS doesn’t run with the Op. With default settings it initializes but then just hangs before taking any samples. If I set cores=1, it completes one chain but then hangs partway through the second. This is on Ubuntu (technically WSL on Windows 10) with PyMC3 3.11.2. Any thoughts on how to fix this?


Thanks for sharing!
cc @twiecki @aseyboldt

This is really cool - thanks for sharing. Would be a great contribution to PyMC3!


I’d be happy to contribute! Before I do, though, any idea why the Op would work with ADVI but cause NUTS to hang indefinitely? I’d like to sort that out.

Unfortunately not, but I agree that this issue needs to be resolved beforehand.