New package: Transforming JAX to PyTensor, for ODEs and other applications

Hello,

There is a blog post about how to transform JAX functions into PyTensor Ops: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs. I wrote a function (icomo.jax2pytensor) that performs those steps automatically. For instance, for the diffrax differential equation solver it is enough to write solver = icomo.jax2pytensor(diffrax.diffeqsolve). The resulting solver has the same API as diffrax.diffeqsolve but accepts and returns pytensor.Variable objects; it doesn't matter whether they are nested in a PyTree or not, and shape inference is also done automatically.
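
For illustration, a minimal sketch of how this might be used inside a PyMC model (untested; the exponential-decay ODE, the priors, and the observed data are made up for the example, and whether plain Python constants such as y0=1.0 pass through unchanged is an assumption):

import numpy as np
import diffrax
import pymc as pm
import icomo

def vector_field(t, y, args):
    # simple exponential decay, dy/dt = -rate * y
    return -args["rate"] * y

ts = np.linspace(0.0, 10.0, 50)
y_obs = np.exp(-0.5 * ts)  # made-up observations

# wrap once, then call it like diffrax.diffeqsolve
solver = icomo.jax2pytensor(diffrax.diffeqsolve)

with pm.Model() as model:
    rate = pm.HalfNormal("rate", sigma=1.0)
    sol = solver(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=0.1,
        y0=1.0,
        args={"rate": rate},  # PyMC random variable goes straight into the solver
        saveat=diffrax.SaveAt(ts=ts),
    )
    pm.Normal("obs", mu=sol.ys, sigma=0.1, observed=y_obs)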

We use this for compartmental models, hence the name ICoMo (Inference of Compartmental Models), but the function should be useful for other applications as well. Importantly, jax2pytensor also wraps functions returned by the wrapped function, so that they can be used inside another function decorated with jax2pytensor. For example, this makes it possible to define an interpolation function, which is often needed to model time-dependent differential equations, separately from the main differential equation (sketched below). Note that this feature of wrapped returned functions hasn't been extensively tested yet, so please report any issues.
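
A hypothetical sketch of that pattern (untested; make_interp, solve, rate_values etc. are illustrative names, and it is assumed that jax2pytensor can be applied as a decorator and passes the diffrax objects through as static arguments):

import diffrax
import icomo

@icomo.jax2pytensor
def make_interp(ts, values):
    # build a linear interpolation of a time-dependent quantity
    interp = diffrax.LinearInterpolation(ts=ts, ys=values)
    return interp.evaluate  # the returned function gets wrapped as well

@icomo.jax2pytensor
def solve(y0, rate_of_t, ts):
    def vector_field(t, y, args):
        return -rate_of_t(t) * y  # time-dependent rate via the interpolation
    return diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=0.1,
        y0=y0,
        saveat=diffrax.SaveAt(ts=ts),
    ).ys

rate_of_t = make_interp(ts, rate_values)  # rate_values: PyTensor variable or data
ys = solve(y0, rate_of_t, ts)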

Link to GitHub: GitHub - Priesemann-Group/icomo: Tools for the inference of compartmental models
Link to SIR-dynamics example: Step-by-step guide — ICoMo Toolbox 1.0.1
Link to API and examples of jax2pytensor: Transform Jax to Pytensor — ICoMo Toolbox 1.0.1

That’s great.

We would like to have native support for this, similar to the existing @pytensor.as_op: Implement helper `@as_jax_op` to wrap JAX functions in PyTensor · Issue #537 · pymc-devs/pytensor · GitHub
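
For reference, the existing decorator works roughly as follows (a minimal sketch: input/output types are declared by hand and the resulting Op has no gradient, which is part of what a JAX-aware helper could improve on):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.ops import as_op

@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def numpy_double(x):
    # plain NumPy function, wrapped as a PyTensor Op
    return 2.0 * x

x = pt.dvector("x")
f = pytensor.function([x], numpy_double(x))
print(f(np.arange(3.0)))  # [0. 2. 4.]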

I hadn’t seen this issue. I can start a PR for it, but it won’t be before next week.

Hi, thank you for creating these helper functions and classes. When this is integrated into PyMC, I think it is fair to say you will have made a significant contribution to the usability of ODE parameter estimation in the library – what you have created is just as easy to use as the uber-slow pm.ode.DifferentialEquation. My question for you: if I want to use icomo.jax2pytensor(diffrax.diffeqsolve) in a hierarchical model where each group has its own set of parameters to be passed to the ODE solver (for-loop implementation below), should I use pytensor.scan to 'loop' over the groups, or create a slightly modified version of icomo.jax2pytensor that internally uses jax.vmap to accomplish the same thing? (A rough vmap sketch follows the loop example below.)

For Loop Implementation Example:

import icomo
import pymc as pm
import pytensor.tensor as pt

with pm.Model() as model:
    # . . .
    sol = []
    for sub_idx, subject in enumerate(coords['subject']):
        subject_y0 = [subject_init_vals[sub_idx]]
        #ode_model_params is a list of pm.Deterministic corresponding to the ODE params
        #each element has length = n_subjects
        subject_model_params = [i[sub_idx] for i in ode_model_params]
        subject_timepoints = tp_data_vector
        subject_t0 = subject_timepoints[ 0]
        subject_t1 = subject_timepoints[-1]

        ode_sol = icomo.jax2pytensor(icomo.diffeqsolve)(
                ts_out=subject_timepoints,
                y0=subject_y0,
                args=subject_model_params,
                ODE=ode_func,  # ode_func returns a list of length 1
            ).ys  
        ode_sol = ode_sol[0].flatten()
        sol.append(ode_sol)

    sol = pt.concatenate(sol)
    # . . .
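
For comparison, the jax.vmap route mentioned in the question might look roughly like this (an untested sketch; batched_solve is a hypothetical helper, and it assumes icomo.diffeqsolve accepts the same keywords as in the loop above and that jax2pytensor handles the stacked per-subject inputs):

import jax
import icomo

def batched_solve(y0_all, params_all):
    # solve one subject's ODE; same call as in the loop above
    def solve_one(y0, params):
        return icomo.diffeqsolve(
            ts_out=tp_data_vector,
            y0=[y0],
            args=params,
            ODE=ode_func,
        ).ys[0]
    # vectorize over axis 0 (the subject axis) of y0_all and of every element of params_all
    return jax.vmap(solve_one)(y0_all, params_all)

with pm.Model() as model:
    # . . .
    sol = icomo.jax2pytensor(batched_solve)(
        subject_init_vals,   # shape (n_subjects,)
        ode_model_params,    # list of variables, each of length n_subjects
    )
    sol = sol.flatten()      # shape (n_subjects * n_timepoints,), as in the loop version
    # . . .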