Gaining speed in sampling an ODE

I was really pleased to see DifferentialEquation in pymc3, so I started working on a problem of mine where I could use it. Currently I do maximum likelihood (ML) estimates, but going Bayesian would be great. I’m already very grateful for the support I got from @dpananos and @michaelosthege (Return value for 2n-dimensional ODE system). Thanks!

My main problem now is that pymc3 is much too slow when sampling. Here is an example notebook:


It is quite lengthy, as it is the actual task I am working on. I simulate some data and perform parameter recovery. The original data I work on has the same structure.

I infer parameters for two different models. The first model (A) is an n-dimensional system of ODEs with only 2 free parameters. This model still seems doable and I get reasonable results, but it is quite slow. Posterior analysis (things like traceplot) is very slow as well.

However, the second model (D), which is a 2*n-dimensional ODE system with 5 free parameters, is way too slow to finish in reasonable time on my machine. But this is the model I actually work on, and ML estimates seem to be OK (not shown in the notebook).

I already hacked the DifferentialEquation class to use solve_ivp, which is much faster than odeint in my case (this also inspired Solve_ivp for Differential Equation).
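
For anyone who wants to try the same swap outside pymc3, here is a minimal sketch that calls both SciPy solvers on one right-hand side. The decay system, the name `rhs`, and the tolerance values are placeholders for illustration and are not taken from the notebook. Which solver is faster depends on the problem.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp

def rhs(t, y, k):
    # Placeholder system (an assumption, not the notebook's model): dy/dt = -k * y
    return -k * y

y0 = np.array([1.0, 2.0])
k = 0.5
t_eval = np.linspace(0.0, 5.0, 11)

# odeint expects f(y, t, ...) by default; tfirst=True lets it reuse f(t, y, ...)
y_odeint = odeint(rhs, y0, t_eval, args=(k,), tfirst=True)

# solve_ivp returns y with shape (n_states, n_times), so transpose to match odeint
sol = solve_ivp(rhs, (t_eval[0], t_eval[-1]), y0, t_eval=t_eval,
                args=(k,), rtol=1e-8, atol=1e-10)
y_ivp = sol.y.T

print(np.max(np.abs(y_odeint - y_ivp)))
```

Note the two differences that are easy to trip over: the argument order of the right-hand side and the orientation of the returned array.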

Is there anything else I could do to speed up computations? Any advice would be very welcome.

I know this seems a bit silly to say on PyMC3’s discourse, but have you thought about using Stan for its ODE capability? Their implementation absolutely crushes mine, so if speed is a concern, that would be my recommendation. DifferentialEquation is still new functionality and there is a lot of work to be done.

I’d also like to add my thanks for starting the DifferentialEquation work; it’s something that could be very useful for my research. Unfortunately, I’ve also found that speed is a limiting factor: my code basically grinds to a halt if I try anything too complex.

I’ve done some profiling and it seems that memory allocations are currently a bottleneck. In 1 minute of sampling, roughly 20 million allocations are made for a total of 1.25 GB of memory, and almost all of it is de-allocated again very quickly. Digging a bit further shows that the vast majority of these allocations sit below odeint’s call_odeint_user_function in the stack and come from theano ops. My guess (and it is only a guess) is that this is due to the use of theano in utils.augment_system of the ode code. Whilst it’s clearly very elegant from a programming perspective to get theano to calculate the Jacobian, I think the result is that theano builds up and tears down its memory framework on every call to the user function, producing the rather extreme number of allocations. Let me know if there’s any other data that might be useful.
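
To give a feel for how often the solver calls back into the user function, here is a small sketch that simply counts the callbacks for a toy problem. The decay system and the `calls` counter are made up for illustration. This is plain SciPy without pymc3 or theano, so it only shows the call frequency, not the allocations themselves.

```python
import numpy as np
from scipy.integrate import odeint

calls = {"n": 0}

def rhs(y, t, k):
    # Stand-in for the user function; any per-call overhead here
    # is paid again on every single evaluation the solver requests.
    calls["n"] += 1
    return -k * y

t_eval = np.linspace(0.0, 5.0, 11)
y = odeint(rhs, np.array([1.0]), t_eval, args=(0.5,))

print(f"{calls['n']} callbacks for {len(t_eval)} output points")
```

Even for a single solve, the function is evaluated many more times than there are output points, and a sampler repeats the whole solve for every gradient evaluation.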

It seems like you profiled at least one source of the bottlenecks quite well. I am still relatively new to optimizing complex mathematical operations, so what would be the obvious way to potentially work on this particular bottleneck? Re-using the old memory?

I am genuinely wondering here. ODE functionality in Pymc3 is already quite amazing!

You’re right: the context switch in the iterations of the ODE solver is a huge problem.

There are a few things that are not too difficult and could lead to some acceleration:

  • numba-compiled functions of sympy-derived augment_system (instead of theano)
  • adjusting absolute/relative tolerances for odeint
  • an option to turn off sensitivities and use a gradient-free sampler (only for small and easy models)
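
As a rough illustration of the tolerance point, the sketch below solves one system with odeint at two tolerance settings and reads the number of right-hand-side evaluations from the solver's info dictionary. The Lotka-Volterra system, its parameter values, and the tolerance values are arbitrary examples. What is acceptable for a given model has to be checked against the accuracy the likelihood needs.

```python
import numpy as np
from scipy.integrate import odeint

def lotka_volterra(y, t, a, b, c, d):
    # Example system only; not one of the models discussed in this thread
    prey, pred = y
    return [a * prey - b * prey * pred, -c * pred + d * prey * pred]

y0 = [10.0, 5.0]
t_eval = np.linspace(0.0, 10.0, 50)
params = (1.0, 0.1, 1.5, 0.075)

def solve(rtol, atol):
    y, info = odeint(lotka_volterra, y0, t_eval, args=params,
                     rtol=rtol, atol=atol, full_output=True)
    # info["nfe"] holds the cumulative number of function evaluations
    return y, int(info["nfe"][-1])

y_tight, nfe_tight = solve(1e-10, 1e-10)
y_loose, nfe_loose = solve(1e-4, 1e-4)

print("evaluations (tight, loose):", nfe_tight, nfe_loose)
print("max difference:", np.max(np.abs(y_tight - y_loose)))
```

Fewer evaluations mean fewer trips through the user function, at the price of a less accurate solution.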

And then there’s the optimal solution, which is much faster, but for which we need more help to implement:

  1. make cross-OS conda-installable package that wraps sundials
  2. sympy analysis to get augment_system
  3. numba-compile augment_system generated code in a way that it acts directly on sundials data types
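
To make step 2 concrete, here is a minimal sketch of deriving an augmented system (states plus forward sensitivities) with sympy and integrating it with SciPy. It is not the code from the proof of concept. The one-state decay model and all names (`aug_state`, `aug_rhs`, `rhs_num`) are hypothetical, and `lambdify` stands in for the numba and sundials steps.

```python
import numpy as np
import sympy as sp
from scipy.integrate import solve_ivp

# Hypothetical one-state, one-parameter system: dy/dt = -k * y
t, y, k = sp.symbols("t y k")
states = sp.Matrix([y])
params = sp.Matrix([k])
f = sp.Matrix([-k * y])

# Forward sensitivities S = dy/dp obey dS/dt = (df/dy) * S + df/dp
n, m = len(states), len(params)
S = sp.Matrix(n, m, lambda i, j: sp.Symbol(f"s_{i}_{j}"))
dS = f.jacobian(states) * S + f.jacobian(params)

# Augmented system: states first, then the flattened sensitivities
aug_state = list(states) + list(S)
aug_rhs = list(f) + list(dS)
rhs_num = sp.lambdify([t, aug_state, list(params)], aug_rhs, modules="numpy")

t_eval = np.linspace(0.0, 5.0, 11)
# Initial sensitivity is 0 because y(0) does not depend on k
sol = solve_ivp(rhs_num, (0.0, 5.0), [1.0, 0.0], t_eval=t_eval,
                args=([0.5],), rtol=1e-8, atol=1e-10)

y_num, dy_dk_num = sol.y
print(np.max(np.abs(dy_dk_num + t_eval * np.exp(-0.5 * t_eval))))
```

For this toy model the sensitivity has the closed form dy/dk = -t * exp(-k * t), which gives a quick check that the generated right-hand side is correct.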

The optimal strategy avoids all the context switches and also avoids copying data around all the time. @aseyboldt already has a proof of concept, but point 1 as well as a clean object-oriented implementation are still under construction.
Based on Adrian’s work, I managed to build an entry point for the sympy-based analysis of a user-provided ODE system.
See here:
