How to improve the efficiency of a loop and find a root

I don’t have a good answer for you but here are some thoughts:

  • AFAIK, there’s no out-of-the-box differentiable root finder for Aesara yet, but the JAX / NumPyro ecosystem has one: jax.scipy.optimize.minimize — JAX documentation

  • You could try wrapping a solver in an Aesara Op, but this would require defining the gradient of the solver’s output yourself (e.g., via the implicit function theorem) rather than differentiating through the solver’s iterations.

  • Supposing your data is of modest dimension (N ≥ 50), 12 minutes for sampling this kind of model doesn’t sound too outrageous. These types of problems are pretty hard in general, and I’m guessing the posterior is going to be rather complicated because of the nonlinearities in this model.

  • You may want to search the existing Discourse topics for threads related to “root”, “optimization”, “Newton” or related keywords. Here’s an example of one such thread: Defining grad() for custom Theano Op that solves nonlinear system of equations - #2 by BioGoertz.
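To make the custom-Op idea concrete, here is a minimal sketch (pure Python, no Aesara, with a hypothetical toy residual function) of how the implicit function theorem gives you the gradient of a root with respect to a parameter: if x*(θ) solves f(x, θ) = 0, then dx*/dθ = −(∂f/∂θ)/(∂f/∂x) evaluated at the root. This is the quantity you would return from the Op’s `grad` method.

```python
def f(x, theta):
    # Hypothetical toy residual: the root of f(x, theta) = 0 is x* = sqrt(theta).
    return x * x - theta

def df_dx(x, theta):
    # Partial derivative of f with respect to x.
    return 2.0 * x

def df_dtheta(x, theta):
    # Partial derivative of f with respect to theta.
    return -1.0

def newton_root(theta, x0=1.0, tol=1e-12, max_iter=100):
    """Plain Newton iteration solving f(x, theta) = 0 for x."""
    x = x0
    for _ in range(max_iter):
        step = f(x, theta) / df_dx(x, theta)
        x -= step
        if abs(step) < tol:
            break
    return x

def root_and_gradient(theta):
    """Root x*(theta) and its derivative via the implicit function theorem:
    dx*/dtheta = -(df/dtheta) / (df/dx), evaluated at the root."""
    x_star = newton_root(theta)
    grad = -df_dtheta(x_star, theta) / df_dx(x_star, theta)
    return x_star, grad

x_star, grad = root_and_gradient(4.0)
# For theta = 4: x* = 2 and dx*/dtheta = 1 / (2 * sqrt(theta)) = 0.25
```

Inside an Aesara Op, `perform` would run the Newton solve and `grad` would return the implicit-function-theorem expression; the key point is that you never need to backpropagate through the Newton iterations themselves.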