Custom theano Op to do numerical integration

I think I've worked out how to do this. I implemented a more general integration Op:

from scipy.integrate import quad
import theano
import theano.tensor as tt
import numpy as np


class Integrate(theano.Op):
    """Theano Op computing a definite integral with ``scipy.integrate.quad``.

    The integrand ``expr`` is a symbolic scalar expression in the integration
    variable ``var`` and zero or more placeholder parameters ``extra_vars``.
    Applying the Op as ``op(start, stop, *params)`` integrates ``expr`` over
    ``var`` from ``start`` to ``stop``, substituting ``params`` for
    ``extra_vars`` positionally.

    Gradients follow the Leibniz integral rule:

    * d/dstart = -expr evaluated at ``start``;
    * d/dstop = expr evaluated at ``stop``;
    * d/dparam_i = integral over [start, stop] of d(expr)/d(extra_vars[i]),
      computed by a nested ``Integrate`` Op.

    Args:
        expr: Symbolic scalar integrand.
        var: Symbolic variable of integration appearing in ``expr``.
        *extra_vars: Symbolic placeholders for the integrand's parameters.
    """

    def __init__(self, expr, var, *extra_vars):
        super().__init__()
        self._expr = expr
        self._var = var
        self._extra_vars = extra_vars
        # Compiled once; quad calls it as func(x, *params). Parameters that the
        # integrand does not actually use are tolerated instead of rejected.
        self._func = theano.function(
            [var] + list(extra_vars),
            self._expr,
            on_unused_input='ignore')

    def make_node(self, start, stop, *extra_vars):
        """Create the Apply node integrating ``expr`` from ``start`` to ``stop``.

        The bounds are cast to the dtype of the integration variable they
        replace. Plain Python/NumPy numbers are accepted, and the graph
        substitution in ``grad`` sees matching types.

        Raises:
            TypeError: If the number of parameters differs from the number of
                ``extra_vars`` the Op was constructed with.
        """
        if len(extra_vars) != len(self._extra_vars):
            raise TypeError(
                'Integrate expects %d extra input(s), got %d'
                % (len(self._extra_vars), len(extra_vars)))
        # No per-application state is stored on the Op: one instance may be
        # applied to many different inputs.
        bounds = [tt.cast(bound, self._var.dtype) for bound in (start, stop)]
        params = [tt.as_tensor_variable(param) for param in extra_vars]
        # quad always yields a Python float, hence a float64 scalar output.
        return theano.Apply(self, bounds + params, [tt.dscalar()])

    def perform(self, node, inputs, out):
        """Evaluate the integral numerically; quad's error estimate is discarded."""
        start, stop, *args = inputs
        value, _abserr = quad(self._func, start, stop, args=tuple(args))
        out[0][0] = np.array(value)

    def grad(self, inputs, grads):
        """Return gradients w.r.t. ``start``, ``stop`` and each parameter."""
        start, stop, *args = inputs
        out, = grads
        # Map the construction-time placeholders onto this node's real inputs.
        replace = dict(zip(self._extra_vars, args))

        # Bound derivatives: the integrand evaluated at each bound.
        dstart = out * theano.clone(
            -self._expr, replace={**replace, self._var: start})
        dstop = out * theano.clone(
            self._expr, replace={**replace, self._var: stop})

        # Parameter derivatives: integrate the partial derivative of the
        # integrand over the same interval. A parameter the integrand ignores
        # gets a zero derivative instead of a DisconnectedInputError.
        if self._extra_vars:
            param_grads = tt.grad(
                self._expr, self._extra_vars, disconnected_inputs='ignore')
        else:
            param_grads = []
        dargs = [
            # Unpack the placeholders: the constructor takes them as *varargs.
            out * Integrate(param_grad, self._var, *self._extra_vars)(
                start, stop, *args)
            for param_grad in param_grads
        ]

        return [dstart, dstop] + dargs

    
## Basic usage

# We define the function we want to integrate
# NOTE(review): the test values presumably serve theano's compute_test_value
# mode; confirm whether they are needed when that mode is off.
x = tt.dscalar('x')
x.tag.test_value = np.zeros(())
a = tt.dscalar('a')
a.tag.test_value = np.ones(())

# Integrand f(x; a) = a**2 * x**2. Here `a` is only a placeholder: the actual
# parameter value is supplied each time the Op is applied.
func = a ** 2 * x**2
integrate = Integrate(func, x, a)

# Check gradients
# verify_grad checks the Op's symbolic gradients w.r.t. (start, stop, a)
# numerically, at the given input values.
from theano.tests.unittest_tools import verify_grad
verify_grad(integrate, (np.array(0.), np.array(1.), np.array(2.)))
verify_grad(integrate, (np.array(-2.), np.array(5.), np.array(8.)))


# Now, we define values for the integral
start = tt.dscalar('start')
start.tag.test_value = np.zeros(())
stop = tt.dscalar('stop')
stop.tag.test_value = np.ones(())
a_ = tt.dscalar('a_')
a_.tag.test_value = np.ones(())

# Note, that a_ != a
# `a_` is a new variable. Integrate substitutes the inputs after
# (start, stop) for its placeholder parameters by position.
val = integrate(start, stop, a_)

# Evaluate the integral and derivatives
# (Run as a plain script, these results are discarded; print them to see them.)
# Integral of 4*x**2 over [0, 1]: expected 4/3.
val.eval({start: 0., stop: 1., a_: 2.})
# d/da_ = 2*a_ * (integral of x**2 over [-2, 1]) = 4 * 3: expected 12.
tt.grad(val, a_).eval({start: -2, stop: 1, a_: 2.})
# d/dstart = -(a_**2 * start**2) = -4 * 1: expected -4.
tt.grad(val, start).eval({start: 1., stop: 2., a_: 2.})

You can evaluate the integral and its derivatives like this:

# Integral of a_**2 * x**2 over [0, 1] with a_ = 2: expected 4/3.
val.eval({start: 0., stop: 1., a_: 2.})
# Derivative w.r.t. a_ over [-2, 1] with a_ = 2, i.e. 2*a_*(1/3 + 8/3): expected 12.
tt.grad(val, a_).eval({start: -2, stop: 1, a_: 2.})

You can use this in PyMC3 like this:

import pymc3 as pm

## Usage in PyMC3
with pm.Model() as model:
    # Model random variables: the integration bounds and the integrand
    # parameter (positional args presumably mu and sd of the Normal).
    start = pm.Normal('start', -5, 1)
    stop = pm.Normal('stop', 5, 1)
    a = pm.Normal('a', 0.5, 1)
    
    # Define the function to integrate in plain theano
    # These are plain theano placeholders, not PyMC3 random variables.
    x = tt.dscalar('x_')
    x.tag.test_value = np.zeros(())
    a_ = tt.dscalar('a_')
    a_.tag.test_value = np.ones(())

    func = a_ ** 2 * x**2
    integrate = Integrate(func, x, a_)

    # Now we plug in the values from the model.
    # The `a_` from above corresponds to the `a` here.
    val = integrate(start, stop, a)
    # Likelihood: a single observation (10) whose mean is the integral.
    # NOTE(review): newer PyMC3 releases name `sd` as `sigma`; confirm against
    # the installed version.
    pm.Normal('y', mu=val, sd=1, observed=10)

You might want to test the integration Op a bit more; I only did some very basic checks. Just plug in different functions and try to break it. If you find a problem, I'd like to hear about it 🙂

If you have trouble applying this to your use case, feel free to ask.

3 Likes