Time-dependent ODE parameters / functions

If you’re willing to try a different ODE solver, there’s Diffrax, an ODE solver library for JAX. I haven’t used it myself, but it looks like it has interpolation included, which should let you evaluate a time-dependent parameter at whatever times the solver asks for. There are ways to get it to work in PyMC, but it might end up being a fair amount of work.