Hello,
There is a blog post on how to turn JAX functions into PyTensor Ops: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs. I wrote a function, `icomo.jax2pytensor`, that performs those steps automatically. For instance, for the diffrax differential equation solver, it is enough to write `solver = icomo.jax2pytensor(diffrax.diffeqsolve)`. Then `solver` accepts `pytensor.Variable`s as inputs and returns them as outputs, with the same API as `diffrax.diffeqsolve`. It doesn't matter whether the variables are inside a pytree or not, and shape inference is also done automatically.
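A wrapper like `jax2pytensor` has to bridge two representations: a PyTensor Op works with flat lists of tensors, while JAX functions such as `diffrax.diffeqsolve` take and return pytrees. The following plain-JAX sketch is not ICoMo's actual code (`sir_step` and `flat_fn` are hypothetical names). It only illustrates the kind of bookkeeping such a wrapper needs: flattening pytrees with `jax.tree_util`, and inferring output shapes and dtypes with `jax.eval_shape` without running the computation.

```python
import jax
import jax.numpy as jnp


def sir_step(y, args):
    """One discrete SIR update; y is a dict (a pytree) of compartments."""
    beta, gamma = args
    new_infections = beta * y["S"] * y["I"]
    recoveries = gamma * y["I"]
    return {
        "S": y["S"] - new_infections,
        "I": y["I"] + new_infections - recoveries,
        "R": y["R"] + recoveries,
    }


y0 = {"S": jnp.full(3, 0.99), "I": jnp.full(3, 0.01), "R": jnp.zeros(3)}
args = (0.3, 0.1)

# Flatten the nested inputs into a flat list of leaves plus a structure (treedef).
in_leaves, in_treedef = jax.tree_util.tree_flatten((y0, args))

# Infer output shapes/dtypes abstractly, without evaluating sir_step on data.
out_shapes = jax.eval_shape(sir_step, y0, args)
out_leaves, out_treedef = jax.tree_util.tree_flatten(out_shapes)


def flat_fn(*flat_inputs):
    """Flat-in, flat-out version of sir_step: the form an Op would work with."""
    y, a = jax.tree_util.tree_unflatten(in_treedef, flat_inputs)
    return jax.tree_util.tree_leaves(sir_step(y, a))


flat_out = flat_fn(*in_leaves)
for spec, value in zip(out_leaves, flat_out):
    print(spec.shape, spec.dtype, value)
```

The inferred `ShapeDtypeStruct`s are what one would need to declare the output variables of an Op before any data flows through it, and `out_treedef` is what lets the flat outputs be reassembled into the original dict structure.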
We use this for compartmental models, hence the name ICoMo (Inference of Compartmental Models), but the function should be useful for other applications too. Importantly, `jax2pytensor` also wraps any functions returned by the wrapped function, so that they can be passed to another function decorated with `jax2pytensor`. For example, this lets you define an interpolation function, which is often needed to model time-dependent differential equations, separately from the main differential equation. Note that wrapping returned functions hasn't been tested extensively yet, so please report any issues.
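To make the interpolation use case concrete, here is a plain-JAX sketch with hypothetical names `make_interpolator` and `simulate_sir`. To keep it short, a fixed-step forward-Euler loop built with `jax.lax.scan` stands in for diffrax. The factory returns a closure over the knot values, and the ODE code only needs to call that closure. Following the post, one would presumably decorate both functions with `icomo.jax2pytensor` so that, for example, `beta_values` can be a PyMC random variable and the returned `beta_t` can be handed to `simulate_sir`; that wrapping step is not exercised here.

```python
import jax
import jax.numpy as jnp


def make_interpolator(t_knots, values):
    """Return a function that linearly interpolates `values` given at `t_knots`."""
    def interp(t):
        return jnp.interp(t, t_knots, values)
    return interp


def simulate_sir(beta_t, gamma, y0, t_grid):
    """Forward-Euler SIR integration with a time-dependent infection rate beta_t(t)."""
    dt = t_grid[1] - t_grid[0]

    def step(y, t):
        S, I, R = y[0], y[1], y[2]
        new_inf = beta_t(t) * S * I * dt   # uses the separately defined interpolator
        recov = gamma * I * dt
        y_next = jnp.stack([S - new_inf, I + new_inf - recov, R + recov])
        return y_next, y_next

    # Scan over all grid points except the last; each step advances by dt.
    _, traj = jax.lax.scan(step, y0, t_grid[:-1])
    return traj


t_knots = jnp.array([0.0, 20.0, 40.0])
beta_values = jnp.array([0.4, 0.2, 0.3])
beta_t = make_interpolator(t_knots, beta_values)

t_grid = jnp.linspace(0.0, 40.0, 401)
y0 = jnp.array([0.99, 0.01, 0.0])
traj = simulate_sir(beta_t, 0.1, y0, t_grid)
print(traj.shape, traj[-1])
```

Keeping the interpolation in its own function means the time dependence (number of knots, interpolation scheme) can be changed without touching the differential equation itself.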
Link to GitHub: GitHub - Priesemann-Group/icomo: Tools for the inference of compartmental models
Link to SIR-dynamics example: Step-by-step guide — ICoMo Toolbox 1.0.1
Link to API and examples of jax2pytensor: Transform Jax to Pytensor — ICoMo Toolbox 1.0.1