Defining grad() for custom Theano Op that solves nonlinear system of equations

FWIW, my solution to a related problem was to use JAX to find the gradient for me: Theano Op using JAX for lightning-fast ODE inference
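What JAX provides here is a pair of compiled functions: the forward computation, and its vector-Jacobian product (VJP), which is what a Theano Op's grad() ultimately has to supply. The sketch below covers only that JAX side and assumes a hypothetical `model` function. The Theano wrapper is omitted, for example a second Op whose perform() calls `model_vjp`, and it may not match exactly what the linked post does.

```python
import jax
import jax.numpy as jnp


def model(theta):
    """Hypothetical JAX-friendly model: exponential decay sampled at 5 times."""
    t = jnp.linspace(0.0, 1.0, 5)
    return theta[0] * jnp.exp(-theta[1] * t)


# Forward pass: what the Op's perform() would evaluate
model_fwd = jax.jit(model)


@jax.jit
def model_vjp(theta, output_grad):
    """output_grad^T @ d(model)/d(theta): the gradient grad() needs for theta."""
    _, pullback = jax.vjp(model, theta)
    (theta_grad,) = pullback(output_grad)  # one primal input -> 1-tuple
    return theta_grad
```

Because the VJP only needs the incoming output gradient, it never materializes the full Jacobian. That matters once the model has many outputs.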

If you can rewrite your system to use jax.scipy.optimize.minimize, jax.numpy.roots, maybe jax.lax.custom_root (if you're bold), or anything else in jax.scipy, then you can use the Theano Op I show in that post. My particular application was parameter estimation for a system of ODEs, but I believe the Op should be able to wrap any JAX-friendly function. The big caveat is that, last I checked (when that was posted), the Op worked fine with ADVI but ran into issues with NUTS, and no one really knows why.
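For a nonlinear system specifically, jax.lax.custom_root is the most direct fit. It differentiates the root with respect to parameters closed over by the residual function using the implicit function theorem, rather than differentiating through the solver's iterations. The sketch below assumes a hypothetical two-equation system (`residual`) and a plain fixed-iteration Newton solver. Swap in your own equations and solver.

```python
import jax
import jax.numpy as jnp

# Double precision so the root and its derivatives are accurate to ~1e-12
jax.config.update("jax_enable_x64", True)


def residual(x, theta):
    """Hypothetical coupled system F(x; theta) = 0 standing in for your equations."""
    return jnp.stack([
        x[0] ** 3 + 2.0 * x[0] + 0.5 * x[1] - theta[0],
        x[1] ** 3 + 2.0 * x[1] + 0.5 * x[0] - theta[1],
    ])


def newton_solve(f, x0, num_iters=50):
    """Plain Newton iterations; custom_root never differentiates through this loop."""
    def step(_, x):
        return x - jnp.linalg.solve(jax.jacobian(f)(x), f(x))
    return jax.lax.fori_loop(0, num_iters, step, x0)


def tangent_solve(g, y):
    """Solve g(z) = y, where g is F linearized in x at the root (a linear map)."""
    return jnp.linalg.solve(jax.jacobian(g)(y), y)


def solve_system(theta, x_init):
    """Root x*(theta); its gradient w.r.t. theta comes from the implicit function theorem."""
    return jax.lax.custom_root(
        lambda x: residual(x, theta),  # theta is closed over, so it receives gradients
        x_init,
        newton_solve,
        tangent_solve,
    )


theta = jnp.array([1.0, 2.0])
x_init = jnp.zeros(2)
x_star = solve_system(theta, x_init)
dx_dtheta = jax.jacobian(solve_system)(theta, x_init)  # 2x2 matrix dx*/dtheta
```

In this toy system dF/dtheta is minus the identity, so the implicit function theorem says dx_dtheta should equal the inverse of dF/dx at the root. The tangent_solve above follows the pattern suggested in the custom_root documentation and builds the full Jacobian, which is only sensible for small systems.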

On the other hand… do you really, really need NUTS? Might I gently recommend seeing whether ABC-SMC could tackle your problem? See the example notebooks here and here. The advantage is that you can use arbitrary black-box code to simulate your observations. Instead of a strict likelihood function, the acceptability of a given sample of parameters is determined by some (any) distance metric between the corresponding simulation and your observations. I think the main place this falls down is if you want a hierarchical prior structure, but if your prior structure is relatively simple you should be fine.
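To make the distance-metric idea concrete, the sketch below implements plain rejection ABC, the simplest relative of ABC-SMC. ABC-SMC moves a population of particles through a sequence of intermediate distributions instead of rejecting in one shot, and this sketch skips all of that. The simulator, the uniform prior, the mean-difference distance, and `epsilon` are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def simulate(key, theta, n=100):
    """Hypothetical black-box simulator: n noisy observations centred on theta."""
    return theta + jax.random.normal(key, (n,))


def distance(sim, obs):
    """Any distance works; here, the absolute gap between sample means."""
    return jnp.abs(jnp.mean(sim) - jnp.mean(obs))


def abc_rejection(key, obs, num_draws=20_000, epsilon=0.1):
    """Keep prior draws whose simulated data land within epsilon of obs."""
    prior_key, sim_key = jax.random.split(key)
    # Hypothetical Uniform(-5, 5) prior on theta
    thetas = jax.random.uniform(prior_key, (num_draws,), minval=-5.0, maxval=5.0)
    sims = jax.vmap(simulate)(jax.random.split(sim_key, num_draws), thetas)
    dists = jax.vmap(distance, in_axes=(0, None))(sims, obs)
    return thetas[dists < epsilon]  # approximate posterior samples


obs = simulate(jax.random.PRNGKey(0), 1.5)
accepted = abc_rejection(jax.random.PRNGKey(1), obs)
```

No likelihood or gradient is ever needed. The simulator only has to be callable, so it could just as well wrap your existing nonlinear solver.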

None of that exactly answered your question, but hope it helps!
