Hi yunus, maybe give the approach I outlined here a try: Theano Op using JAX for lightning-fast ODE inference
It doesn’t work with NUTS, for some unknown reason, but it does work with ADVI, and it’s much faster. It’s also convenient because you can write your ODE in a familiar way. You could use FullRankADVI to capture the (inevitable) correlations between the ODE parameters.
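
To give a feel for what "writing your ODE in a familiar way" looks like on the JAX side, here is a minimal sketch. It is not the code from the linked post and leaves out the Theano Op wrapper entirely. It assumes `jax.experimental.ode.odeint` as the solver and uses a toy exponential decay model; the names `decay`, `final_state` and `k` are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t, k):
    # Right-hand side of the toy ODE dy/dt = -k * y (illustrative model only).
    return -k * y


def final_state(k):
    # Solve from t = 0 to t = 2 with y(0) = 1 and return y at the last time point.
    y0 = jnp.array([1.0])
    t = jnp.linspace(0.0, 2.0, 5)
    ys = odeint(decay, y0, t, k)  # shape: (len(t), len(y0))
    return ys[-1, 0]


k = jnp.array(0.5)  # hypothetical parameter value

y_end = final_state(k)
# JAX differentiates through the solver, which is what gradient-based
# inference such as ADVI needs from the ODE solution.
dy_dk = jax.grad(final_state)(k)

print(y_end, dy_dk)
```

For this toy model the solution is known in closed form, y(t) = exp(-k t), so both the value and the gradient with respect to `k` can be checked by hand before plugging a real model in.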