Defining grad() for a custom Theano Op that solves a nonlinear system of equations

Hi all,

I’m writing a custom Theano Op that wraps scipy.optimize.root() to find the roots of a nonlinear system of three equations and three unknowns, though I’m only interested in the value of the third root. In pseudocode, it looks like:

def system(unknowns, parameters, constants):
    return [ F[0], F[1], F[2] ]

def solve_system(parameters, constants):
    # x0: some initial guess for roots
    roots = scipy.optimize.root(system, x0, args=(parameters, constants))
    return roots.x[2]  # return the value of the third root

I’ve also written a Theano Op that successfully returns the root of interest as a vector, since I’m looping over a vector of constants and want to see this root at the various constant values. In pseudocode again:

class MyCustomOp(theano.Op):
    itypes = [tt.dscalar, tt.dvector]
    otypes = [tt.dvector]

    def __init__(self, constants):
        self.constants = constants

    def perform(self, node, inputs, outputs):
        scalar_input, vector_input = inputs

        # one root per constant value (N assumed to be the number of constants)
        N = len(self.constants)
        vector_of_roots = np.empty(N)
        for i in range(N):
            vector_of_roots[i] = solve_system(parameters=[scalar_input, vector_input],
                                              constants=self.constants[i])
        outputs[0][0] = vector_of_roots

Now this works perfectly fine: I’ve tested it, and MyCustomOp returns an array of the expected values. My question is: how would I define grad() so that I can use this in NUTS sampling?

I’ve got two inputs, one scalar (let’s call it S) and one vector (V). There are three equations in the system and N values in vector_of_roots. As I understand it, grad() needs to return a list of length 2, but I don’t know how to format this list.

Thanks in advance!

FWIW, my solution to a related problem was to use JAX to find the gradient for me: Theano Op using JAX for lightning-fast ODE inference

If you can rewrite your system to use jax.scipy.optimize.minimize, jax.numpy.roots, or maybe jax.lax.custom_root (if you’re bold), or anything else in jax.scipy, then you can use the Theano Op I show in that post. My particular application was parameter estimation for a system of ODEs, but I believe the Op should be able to wrap any JAX-friendly function. The big caveat is that, last I checked (when that was posted), the Op was fine for ADVI but ran into issues with NUTS, and no one really knows why.
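If you would rather hand-roll the gradient, the usual trick for a root-finding step is the implicit function theorem: at a root x of F(x, p) = 0, dx/dp = -(dF/dx)^{-1} dF/dp, so nothing has to be differentiated through the solver itself. Below is a rough sketch in NumPy/SciPy, not a drop-in grad(). The toy system, its two hand-written Jacobians, and the name third_root_and_gradient are all made up for illustration; your own system and Jacobians would replace them.

```python
import numpy as np
from scipy.optimize import root


def system(x, s, v, c):
    """Hypothetical toy residuals F(x; s, v, c) in three unknowns."""
    return np.array([
        x[0] - s * c,
        x[1] - v[0] * x[0],
        x[2] + x[2] ** 3 - v[1] * x[1],
    ])


def jac_unknowns(x, s, v, c):
    """dF/dx, shape (3, 3), written by hand for the toy system."""
    return np.array([
        [1.0, 0.0, 0.0],
        [-v[0], 1.0, 0.0],
        [0.0, -v[1], 1.0 + 3.0 * x[2] ** 2],
    ])


def jac_parameters(x, s, v, c):
    """dF/d(s, v[0], v[1]), shape (3, 3): one column per parameter."""
    return np.array([
        [-c, 0.0, 0.0],
        [0.0, -x[0], 0.0],
        [0.0, 0.0, -x[1]],
    ])


def third_root_and_gradient(s, v, c):
    """Solve the system, then differentiate the root implicitly."""
    x = root(system, x0=np.ones(3), args=(s, v, c)).x
    # Implicit function theorem: dx/dp = -(dF/dx)^{-1} dF/dp
    dx_dp = -np.linalg.solve(jac_unknowns(x, s, v, c),
                             jac_parameters(x, s, v, c))
    # Third root and its derivatives w.r.t. (s, v[0], v[1])
    return x[2], dx_dp[2]


s, v = 0.7, np.array([1.3, 0.9])       # scalar input S and vector input V
constants = np.array([0.5, 1.0, 2.0])  # one solve per constant

pairs = [third_root_and_gradient(s, v, c) for c in constants]
roots = np.array([r for r, _ in pairs])  # shape (N,)
grads = np.array([g for _, g in pairs])  # shape (N, 3)
```

As far as I understand the grad(inputs, output_grads) contract, with an output gradient g of length N the numeric counterparts of the two returned entries would be g @ grads[:, 0] for S (a scalar) and g @ grads[:, 1:] for V (same length as V). Since grad() has to return symbolic variables, those products would normally be computed in the perform() of a second helper Op.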

On the other hand… do you really, really need NUTS? Might I gently recommend seeing if ABC-SMC could tackle your problem? See the example notebooks here and here. The advantage is that you can use arbitrary black-box code to simulate your observations, and rather than a strict likelihood function, the acceptability of a given sample of parameters is determined by some (any) distance metric between the corresponding simulation and your observations. I think the main place this falls down is if you want to use hierarchical prior structure, but if you have a relatively simple prior structure you should be fine.
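To make the acceptance idea concrete, here is a bare-bones ABC loop in NumPy. Note that it is plain rejection ABC, not the SMC variant, and it does not use the PyMC3 API at all. The simulator, the flat prior, the distance on sample means, and the threshold epsilon are all invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def simulate(theta, n, rng):
    """Hypothetical black-box simulator: n draws from Normal(theta, 1)."""
    return rng.normal(theta, 1.0, size=n)


def distance(sim, obs):
    """Any metric works; here, the absolute difference of sample means."""
    return abs(sim.mean() - obs.mean())


# Pretend observations, generated with a "true" theta of 2
observed = simulate(2.0, 200, rng)

epsilon = 0.1  # acceptance threshold (assumed; needs tuning per problem)
accepted = []
for _ in range(20000):
    theta = rng.uniform(-5.0, 5.0)  # draw a candidate from a flat prior
    # Keep the candidate only if its simulation lands close to the data
    if distance(simulate(theta, 200, rng), observed) < epsilon:
        accepted.append(theta)

posterior_samples = np.array(accepted)
```

The accepted draws approximate the posterior without a likelihood ever being written down. The SMC version improves on this by shrinking the threshold over a sequence of stages instead of fixing it up front.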

None of that exactly answered your question, but hope it helps!
