Best practice for nonlinear, time-dependent PDE likelihood: scan(), custom Op with JAX, custom Op with FEniCS

Thanks @junpenglao. I read up on jax.scipy.optimize a bit more and noticed that it does not currently support differentiation (though this is a planned feature). So it would still be useful for solving the nonlinear PDE at each time step, but the gradients that pymc.NUTS() requires would need to be calculated manually. I think this is doable, but it's still not obvious to me whether this implementation will be faster than simply using aesara.scan() or a FEniCS implementation.
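
To make the "manually calculated gradients" idea concrete, here is a rough sketch of what I have in mind. It is not tied to my actual model: the PDE (a 1D periodic reaction-diffusion equation), the hand-rolled Newton solver (standing in for the optimizer), and every name in it (residual, newton_solve, step, loss, theta, dt) are assumptions made up for illustration. The point is only that jax.custom_vjp lets you supply the gradient of one implicit time step via the implicit function theorem, so the solver itself never has to be differentiated.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def residual(u_new, u_old, theta, dt=0.01):
    # Backward-Euler residual for the toy PDE u_t = theta * u_xx - u**3
    # on a periodic 1D grid with unit spacing (hypothetical example problem).
    lap = jnp.roll(u_new, 1) - 2.0 * u_new + jnp.roll(u_new, -1)
    return u_new - u_old - dt * (theta * lap - u_new**3)


def newton_solve(u_old, theta, n_iter=20):
    # Plain Newton iteration on residual(u_new) = 0; stands in for any
    # black-box nonlinear solver that we do not want to differentiate.
    u = u_old
    for _ in range(n_iter):
        J = jax.jacfwd(residual)(u, u_old, theta)
        u = u - jnp.linalg.solve(J, residual(u, u_old, theta))
    return u


@jax.custom_vjp
def step(u_old, theta):
    # One implicit time step: u_new such that residual(u_new, u_old, theta) = 0.
    return newton_solve(u_old, theta)


def step_fwd(u_old, theta):
    u_new = newton_solve(u_old, theta)
    return u_new, (u_new, u_old, theta)


def step_bwd(saved, g):
    u_new, u_old, theta = saved
    # Implicit function theorem: with J = dF/du_new at the solution,
    # solve J^T lam = g, then pull -lam back through dF/du_old and dF/dtheta.
    J = jax.jacfwd(residual)(u_new, u_old, theta)
    lam = jnp.linalg.solve(J.T, g)
    _, pullback = jax.vjp(lambda uo, th: residual(u_new, uo, th), u_old, theta)
    return pullback(-lam)


step.defvjp(step_fwd, step_bwd)


def loss(theta, u0, n_steps=5):
    # Stand-in for a log-likelihood: march n_steps forward, then reduce.
    u = u0
    for _ in range(n_steps):
        u = step(u, theta)
    return jnp.sum(u**2)


u0 = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 16, endpoint=False))
theta0 = jnp.asarray(0.5)
grad_manual = jax.grad(loss)(theta0, u0)
```

The backward pass costs one linear solve with the transposed Jacobian per time step, no matter how many iterations the forward solve needed. Whether that actually beats aesara.scan() or FEniCS for a realistic mesh is exactly the part I can't tell without benchmarking.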