JAX (and leapfrog integration) for approximate solutions of differential equations?

This question is more about JAX than PyMC3 (it isn’t really about PyMC3 at all). I’m curious whether the following workflow is generally valid for approximating solutions to differential equations:

  1. Start with some function
  2. Compute its gradient with respect to each independent variable
  3. Use leapfrog integration (or another symplectic method), given a step size and path length, to estimate the endpoint
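A minimal JAX sketch of those three steps follows. One caveat on efficacy: leapfrog integrates Hamilton’s equations for a separable Hamiltonian, dq/dt = p and dp/dt = -∇U(q). So the function from step 1 plays the role of the potential U, and the approach only covers ODEs that can be written in that form (q'' = -∇U(q)). The potential, `leapfrog` helper, step size and path length below are hypothetical choices for illustration. A harmonic oscillator is used because its exact solution (q = cos t, p = -sin t) makes the endpoint easy to check.

```python
import jax
import jax.numpy as jnp


def potential(q):
    # Hypothetical example: harmonic oscillator, U(q) = q^2 / 2,
    # so the equations of motion are dq/dt = p, dp/dt = -q.
    return 0.5 * jnp.sum(q ** 2)


def leapfrog(potential_fn, q0, p0, step_size, path_length):
    """Integrate dq/dt = p, dp/dt = -grad U(q) with the leapfrog scheme."""
    grad_u = jax.grad(potential_fn)  # step 2: gradient w.r.t. each coordinate
    n_steps = int(round(path_length / step_size))

    def body(_, state):
        q, p = state
        p = p - 0.5 * step_size * grad_u(q)  # half step in momentum
        q = q + step_size * p                # full step in position
        p = p - 0.5 * step_size * grad_u(q)  # half step in momentum
        return q, p

    # step 3: repeat the update until the path length is covered
    return jax.lax.fori_loop(0, n_steps, body, (q0, p0))


q_end, p_end = leapfrog(potential, jnp.array([1.0]), jnp.array([0.0]),
                        step_size=0.01, path_length=1.0)
print(q_end, p_end)  # exact solution at t = 1: q = cos(1), p = -sin(1)
```

Leapfrog’s global error is O(step_size²), so halving the step size should cut the endpoint error roughly fourfold. What it buys over a generic solver is near-conservation of energy over long paths, which is why HMC uses it, rather than high accuracy at the endpoint.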

Any thoughts on both the efficacy of this approach and how well JAX supports it?

Any reason you wouldn’t want to solve the ODE directly with JAX, like this?
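For reference, a sketch of solving an ODE directly in JAX, assuming the adaptive-step solver `jax.experimental.ode.odeint` (it lives under `jax.experimental`, so its location may change between versions; Diffrax is a separately maintained alternative). The decay equation, rate `k` and time grid are hypothetical. The solver is differentiable, so gradients with respect to parameters come for free.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t, k):
    # Hypothetical example ODE: dy/dt = -k * y, exact solution y(t) = y0 * exp(-k t).
    return -k * y


y0 = jnp.array([1.0])
ts = jnp.array([0.0, 1.0, 2.0])

# Adaptive-step solve; returns the solution at every time in `ts`.
ys = odeint(decay, y0, ts, 0.5)
print(ys[:, 0])  # exact values: 1, exp(-0.5), exp(-1)


def y_at_1(k):
    # Solution at t = 1 as a function of the rate parameter k.
    return odeint(decay, y0, ts, k)[1, 0]


# Differentiate through the solver; exact value is -exp(-0.5).
dy_dk = jax.grad(y_at_1)(0.5)
print(dy_dk)
```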

FWIW, that Op should work for any JAX-friendly function; it doesn’t have to be an ODE.
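As a rough illustration of “JAX-friendly”, assuming it means a pure function of arrays built from `jax.numpy` operations so that `jax.jit` and `jax.grad` can trace it, here is a hypothetical non-ODE function (an unnormalized Gaussian log-density) compiled and differentiated:

```python
import jax
import jax.numpy as jnp


def log_density(mu, x):
    # Hypothetical non-ODE example: unnormalized Gaussian log-density
    # with unit variance, summed over the observations in x.
    return -0.5 * jnp.sum((x - mu) ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Pure and built from jax.numpy ops, so it can be compiled and differentiated.
value = jax.jit(log_density)(1.0, x)
grad_mu = jax.jit(jax.grad(log_density))(1.0, x)  # d/dmu = sum(x - mu)
print(value, grad_mu)
```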
