This question is more about JAX (not PyMC3 at all, actually). I'm curious whether the following workflow is generally valid for approximating solutions to differential equations:
- Start with some function
- Compute its gradient with respect to each independent variable
- Use leapfrog integration (or another symplectic method), given a step size and path length, to estimate the endpoint
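The steps above can be sketched in JAX roughly as follows. This is a minimal sketch under some assumptions: the "function" is treated as a potential `U(q)`, and the differential equation is the Hamiltonian system dq/dt = p, dp/dt = -∇U(q) with unit mass, which is the setting leapfrog is designed for. The names `potential`, `leapfrog`, and `leapfrog_jit` are hypothetical, and the harmonic-oscillator potential is chosen only because its exact solution is known, so the endpoint can be checked. `jax.grad` supplies the partial derivatives, and `jax.lax.scan` plus `jax.jit` keep the loop compiled.

```python
import jax
import jax.numpy as jnp


def potential(q):
    # Hypothetical example potential: harmonic oscillator, U(q) = 0.5 * |q|^2.
    # Exact dynamics: q(t) = q0*cos(t) + p0*sin(t), which makes the result easy to check.
    return 0.5 * jnp.sum(q ** 2)


# jax.grad returns dU/dq_i for every component of q (one partial per variable).
grad_potential = jax.grad(potential)


def leapfrog(q, p, step_size, n_steps):
    """Integrate dq/dt = p, dp/dt = -grad U(q) with kick-drift-kick leapfrog (unit mass assumed)."""

    def one_step(carry, _):
        q, p = carry
        p = p - 0.5 * step_size * grad_potential(q)  # half step in momentum
        q = q + step_size * p                        # full step in position
        p = p - 0.5 * step_size * grad_potential(q)  # half step in momentum
        return (q, p), None

    (q, p), _ = jax.lax.scan(one_step, (q, p), xs=None, length=n_steps)
    return q, p


# n_steps sets the scan length, so it must be a static (compile-time) argument.
leapfrog_jit = jax.jit(leapfrog, static_argnums=(3,))

# Usage: integrate over a path length of 1.0 and compare with the exact solution.
q0 = jnp.array([1.0, 0.0])
p0 = jnp.array([0.0, 1.0])
path_length = 1.0
step_size = 0.01
n_steps = int(round(path_length / step_size))

q_end, p_end = leapfrog_jit(q0, p0, step_size, n_steps)
q_exact = q0 * jnp.cos(path_length) + p0 * jnp.sin(path_length)
print("max endpoint error:", float(jnp.max(jnp.abs(q_end - q_exact))))
```

Leapfrog's global error shrinks roughly as the square of the step size, and because it is symplectic the energy `U(q) + 0.5 * |p|^2` stays close to its initial value over long paths rather than drifting. For differential equations that are not of this Hamiltonian form, a general-purpose ODE solver is usually a better fit.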
Any thoughts on both the efficacy of this approach and the functionality of JAX?