How to use JAX ODEs and Neural Networks in PyMC

"PyMC's strength comes from its expressiveness. If you have a data-generating process and want to infer parameters of interest, all you need to do is write it down, choose some priors, and let it sample.

Sometimes this is easier said than done, especially the “write it down” part. With Python’s rich ecosystem, it’s often the case that you already have a generative function, but it’s written in another framework, and you would like to use it in PyMC. Thanks to the highly composable nature of the PyMC backend, this is simple. Even simpler if that framework can also provide you gradients for free!

In this blog post, we show how you can reuse code from another popular auto-diff framework, JAX, directly in PyMC."
