I think I worked out how to do that. I implemented a more general integration Op:
```python
from scipy.integrate import quad

import theano
import theano.tensor as tt
import numpy as np


class Integrate(theano.Op):
    def __init__(self, expr, var, *extra_vars):
        super().__init__()
        self._expr = expr
        self._var = var
        self._extra_vars = extra_vars
        # Compile the integrand once; it maps (var, *extra_vars) -> expr.
        self._func = theano.function(
            [var] + list(extra_vars),
            self._expr,
            on_unused_input='ignore')

    def make_node(self, start, stop, *extra_vars):
        self._extra_vars_node = extra_vars
        assert len(self._extra_vars) == len(extra_vars)
        self._start = start
        self._stop = stop
        vars = [start, stop] + list(extra_vars)
        return theano.Apply(self, vars, [tt.dscalar().type()])

    def perform(self, node, inputs, out):
        start, stop, *args = inputs
        val = quad(self._func, start, stop, args=tuple(args))[0]
        out[0][0] = np.array(val)

    def grad(self, inputs, grads):
        start, stop, *args = inputs
        out, = grads
        replace = dict(zip(self._extra_vars, args))

        # Derivative w.r.t. the lower limit is -f(start),
        # w.r.t. the upper limit it is f(stop).
        replace_ = replace.copy()
        replace_[self._var] = start
        dstart = out * theano.clone(-self._expr, replace=replace_)

        replace_ = replace.copy()
        replace_[self._var] = stop
        dstop = out * theano.clone(self._expr, replace=replace_)

        # Derivative w.r.t. each extra variable is the integral
        # of the derivative of the integrand.
        grads = tt.grad(self._expr, self._extra_vars)
        dargs = []
        for grad in grads:
            # Note the unpacking of self._extra_vars: each gradient
            # integral gets the same extra variables as the original Op.
            integrate = Integrate(grad, self._var, *self._extra_vars)
            darg = out * integrate(start, stop, *args)
            dargs.append(darg)

        return [dstart, dstop] + dargs
```
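In case the `grad` method looks opaque: it just encodes the Leibniz integral rule for a well-behaved integrand $f(x, \theta)$ with extra parameters $\theta$:

$$
\frac{\partial}{\partial b} \int_a^b f(x, \theta)\,dx = f(b, \theta), \qquad
\frac{\partial}{\partial a} \int_a^b f(x, \theta)\,dx = -f(a, \theta), \qquad
\frac{\partial}{\partial \theta} \int_a^b f(x, \theta)\,dx = \int_a^b \frac{\partial f(x, \theta)}{\partial \theta}\,dx
$$

The first two terms are the `dstart` and `dstop` clones, and the last one is why the gradients with respect to the extra variables are themselves `Integrate` Ops.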
## Basic usage
```python
# We define the function we want to integrate.
x = tt.dscalar('x')
x.tag.test_value = np.zeros(())
a = tt.dscalar('a')
a.tag.test_value = np.ones(())

func = a ** 2 * x ** 2
integrate = Integrate(func, x, a)

# Check the gradients against finite differences.
from theano.tests.unittest_tools import verify_grad
verify_grad(integrate, (np.array(0.), np.array(1.), np.array(2.)))
verify_grad(integrate, (np.array(-2.), np.array(5.), np.array(8.)))

# Now we define values for the integral.
start = tt.dscalar('start')
start.tag.test_value = np.zeros(())
stop = tt.dscalar('stop')
stop.tag.test_value = np.ones(())
a_ = tt.dscalar('a_')
a_.tag.test_value = np.ones(())

# Note that a_ is a new variable; it is not the same as the a above.
val = integrate(start, stop, a_)

# Evaluate the integral and its derivatives.
val.eval({start: 0., stop: 1., a_: 2.})
tt.grad(val, a_).eval({start: -2., stop: 1., a_: 2.})
tt.grad(val, start).eval({start: 1., stop: 2., a_: 2.})
```
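As an extra sanity check (my addition, not part of the original checks), you can compare the results against the closed form of this particular integral, $\int_s^t a^2 x^2\,dx = a^2 (t^3 - s^3)/3$:

```python
# Analytic value: a**2 * (stop**3 - start**3) / 3.
print(val.eval({start: 0., stop: 1., a_: 2.}))  # ~1.3333, i.e. 4/3
print(2. ** 2 * (1. ** 3 - 0. ** 3) / 3)        # 1.3333...

# Derivative w.r.t. a: 2 * a * (stop**3 - start**3) / 3.
print(tt.grad(val, a_).eval({start: -2., stop: 1., a_: 2.}))  # ~12
print(2 * 2. * (1. ** 3 - (-2.) ** 3) / 3)                    # 12.0
```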
## Usage in PyMC3

You can use the Op inside a PyMC3 model like this:

```python
import pymc3 as pm

with pm.Model() as model:
    start = pm.Normal('start', -5, 1)
    stop = pm.Normal('stop', 5, 1)
    a = pm.Normal('a', 0.5, 1)

    # Define the function to integrate in plain theano.
    x = tt.dscalar('x_')
    x.tag.test_value = np.zeros(())
    a_ = tt.dscalar('a_')
    a_.tag.test_value = np.ones(())

    func = a_ ** 2 * x ** 2
    integrate = Integrate(func, x, a_)

    # Now we plug in the values from the model.
    # The a_ from above corresponds to the a here.
    val = integrate(start, stop, a)
    pm.Normal('y', mu=val, sd=1, observed=10)
```
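From here you can sample as usual. A minimal sketch (the draw and tune counts are placeholders, not a recommendation):

```python
with model:
    trace = pm.sample(1000, tune=1000)

pm.summary(trace)
```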
You might want to test the integration Op a bit more; I only did some very basic checks. Just plug in different functions and try to break it. If you find a problem, I'd like to hear about it.
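For example, a quick extra check with a different integrand might look something like this (my addition, using a Gaussian-type integrand with a hypothetical parameter `b`):

```python
# Check the Op on f(x) = exp(-b * x**2) as well.
x2 = tt.dscalar('x2')
x2.tag.test_value = np.zeros(())
b = tt.dscalar('b')
b.tag.test_value = np.ones(())

integrate_gauss = Integrate(tt.exp(-b * x2 ** 2), x2, b)
verify_grad(integrate_gauss, (np.array(0.), np.array(1.), np.array(2.)))
```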
If you have trouble with this in your use-case, feel free to ask.