Time-dependent ODE parameters / functions

I have an ODE function ode(y, t, theta, influent), where influent is an array of values at different times (experimental data used as input to the system).

I have this function in my ODE:

import numpy as np

def get_influent_at_time(t, influent):
    time_data = influent[:, 0]  # Get the time data, first column of the influent data

    # Interpolate to find value of each input at time t
    input1 = np.interp(t, time_data, influent[:, 2])
    input2 = ...  # ...
    ...

    input_at_time_t = [input1, input2, ...]
    return input_at_time_t

This works fine when t is an actual number (calling solve_ivp or odeint directly), but within PyMC the DifferentialEquation used for sampling does not support this, because t is a TensorVariable. It also does not work with the sunode library, where t is a sympy.Symbol.
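For contrast, here is a minimal sketch of the case that does work. SciPy's odeint calls the right-hand side with a plain float t, so np.interp has no trouble. The one-state model, theta = 0.5 and the influent table below are hypothetical, chosen only to mirror the layout in the question (time in column 0, input1 in column 2).

```python
import numpy as np
from scipy.integrate import odeint

def get_influent_at_time(t, influent):
    time_data = influent[:, 0]  # first column: measurement times
    # linear interpolation of the input stored in column 2, as in the question
    return [np.interp(t, time_data, influent[:, 2])]

def ode(y, t, theta, influent):
    # hypothetical one-state model: dy/dt = input1(t) - theta * y
    input1 = get_influent_at_time(t, influent)[0]
    return [input1 - theta * y[0]]

# hypothetical influent table: time, an unused column, input1
influent = np.array([[0.0, 0.0, 1.0],
                     [5.0, 0.0, 3.0],
                     [10.0, 0.0, 2.0]])

t_eval = np.linspace(0.0, 10.0, 11)
# t is a plain float at every call, so np.interp works here
sol = odeint(ode, [0.0], t_eval, args=(0.5, influent))
```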

Anyone have any advice or ways to get around this?


I got a bit lost at the part where you defined

    # Interpolate to find value of each input at time t
    input1 = np.interp(t, time_data, influent[:, 2])

For my time-varying parameters, which worked within PyMC's DifferentialEquation, I defined them this way:

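from pymc.ode import DifferentialEquation

def ode_model(state, t, params):
    I_W = state[0]    # hypothetical state layout
    I_Wa = params[0]  # hypothetical parameter layout
    b_k = params[1]   # b_k is sampled by PyMC and passed in through params
    I_W_prime = I_Wa - b(t, q=b_k) * I_W
    return [I_W_prime]

mcmc_ode = DifferentialEquation(
    func=ode_model,
    times=tspan,
    n_states=...,
    n_theta=...,
    t0=t0
)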

In this case, b(t, q=b_k) is time-varying. It is a partial function, defined like this:

from functools import partial
import numpy as np

def Logisticfunc(t, q, t0=0):
    return 1 / (1 + np.exp(-q * (t - t0)))

b = partial(Logisticfunc, q=10.2)

Because b is a Python partial function, it is only evaluated when PyMC calls b(t) within ode_model. b_k is a parameter and is passed in through params in ode_model(state, t, params).
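A quick check of how the partial behaves, reusing the Logisticfunc above: q is frozen when the partial is created, t is supplied at call time, and a q passed at call time (as in b(t, q=b_k) in ode_model) overrides the frozen value.

```python
from functools import partial
import numpy as np

def Logisticfunc(t, q, t0=0):
    return 1 / (1 + np.exp(-q * (t - t0)))

# q is frozen at 10.2; t stays open until b(t) is called
b = partial(Logisticfunc, q=10.2)

print(b(0.0))         # 0.5, the midpoint of the logistic
print(b(1.0, q=0.0))  # 0.5, because the call-time q=0.0 overrides 10.2
```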

I basically want to interpolate between my data points to find the values of input1, input2, … at a specific time t. So it is not an “equation” f(t) like b(t) = 1 / (1 + exp(-q*(t - t0))).

Do you think I could adapt it somehow to be in line with what worked for you?

Ah, I see.
I am not too familiar with finding values at a specific time t in this way, though.

Do you need the posterior distribution for each of input1, input2, …? Or do you just need their values?

If it is the former, you will need to put a prior such as pm.Gamma or pm.Normal on each of input1, input2, …

input1 = pm.Gamma("input1", alpha=alpha, beta=beta)  # the first argument is the variable's name

If it is the latter, wouldn't you just need to run solve_ivp or odeint outside of PyMC?

Yes, I need the value of each input at time t, not a distribution. But I have other parameters in my model that I want distributions over; these are just time-dependent inputs to the model.

I understand your question now!

Yes, that is the trouble with working with TensorVariables.
Someone from the PyMC developer team can clarify this point, but I think a TensorVariable only supports arithmetic operations like addition and subtraction, not arbitrary NumPy functions.

In my case, I wanted to use np.gradient to get new_cases in my model, but a TensorVariable does not support np.gradient either. So instead of
new_cases = np.gradient(modelled_cum_dths)

I did this:

new_cases = modelled_cum_dths[1:] - modelled_cum_dths[:-1]
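Note that the slicing version is not a drop-in replacement for np.gradient: it gives n - 1 forward differences, while np.gradient returns n values, using central differences in the interior and one-sided differences at the two ends. If the central differences matter, the interior part can also be written with slicing and arithmetic alone. The cumulative series below is hypothetical.

```python
import numpy as np

modelled_cum_dths = np.array([0.0, 1.0, 3.0, 6.0, 10.0])  # hypothetical cumulative series

grad = np.gradient(modelled_cum_dths)                   # length n: one-sided at ends, central inside
fwd = modelled_cum_dths[1:] - modelled_cum_dths[:-1]    # length n - 1: forward differences
central = (modelled_cum_dths[2:] - modelled_cum_dths[:-2]) / 2  # slicing-only version of grad[1:-1]
```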

So my way was to convert unsupported functions to arithmetic operations. But if anyone else knows a better way, let me know.
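The same arithmetic-only idea can, I think, be applied to the interpolation itself. This is a sketch of one possible approach, not something tested inside DifferentialEquation or sunode: piecewise-linear interpolation equals the first data value, plus a linear term, plus one ramp max(t - t_i, 0) per interior knot, and max(x, 0) = (x + |x|)/2 needs only arithmetic and abs(). The helper names relu and make_ramp_interp and the influent table are hypothetical. I am assuming, but have not checked, that abs() and the arithmetic operators behave on a PyTensor TensorVariable and on a sympy.Symbol the way they do on the floats and arrays used here.

```python
import numpy as np

def relu(x):
    # max(x, 0) written with arithmetic and abs() only
    return (x + abs(x)) / 2

def make_ramp_interp(time_data, values):
    """Build f(t) that matches np.interp(t, time_data, values).

    time_data must be strictly increasing. Constants are converted to plain
    Python floats so they combine cleanly with symbolic objects (assumption).
    """
    time_data = np.asarray(time_data, dtype=float)
    values = np.asarray(values, dtype=float)
    slopes = np.diff(values) / np.diff(time_data)  # slope of each segment
    kinks = np.diff(slopes)                         # slope change at each interior knot
    t_first, t_last = float(time_data[0]), float(time_data[-1])
    v0, s0 = float(values[0]), float(slopes[0])
    terms = list(zip(time_data[1:-1].tolist(), kinks.tolist()))

    def f(t):
        # clamp to [t_first, t_last]; np.interp holds the end values constant outside
        t = t_first + relu(t - t_first)  # max(t, t_first)
        t = t_last - relu(t_last - t)    # min(t, t_last)
        out = v0 + s0 * (t - t_first)
        for knot, dk in terms:
            out = out + dk * relu(t - knot)
        return out

    return f

# hypothetical influent table: time in column 0, input1 in column 2
influent = np.array([[0.0, 0.0, 1.0],
                     [2.0, 0.0, 4.0],
                     [5.0, 0.0, 2.5],
                     [9.0, 0.0, 3.0]])
input1_at = make_ramp_interp(influent[:, 0], influent[:, 2])

ts = np.linspace(-1.0, 10.0, 23)
print(np.allclose(input1_at(ts), np.interp(ts, influent[:, 0], influent[:, 2])))  # True
```

Each interior knot adds one term, so a long time series produces a large symbolic expression, and the derivative with respect to t is undefined exactly at the knots.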

If you roughly know the functional family of input1 (or you are okay with approximating your data), you could replace your np.interp call with partial functions b(t), c(t), …, like I did.

I knew my time-varying data followed logistic trends, so I just used logistic functions to model them. If your data follow sine or cosine trends (e.g., seasonal flu), a periodic function family may be more appropriate.
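If you would rather not sample the shape of the trend, one option (an assumption on my part, not what was done above) is to fit the logistic to the data once with SciPy's curve_fit and freeze the fitted values in a partial. The data here are hypothetical and noise-free, and p0 is just a starting guess.

```python
from functools import partial
import numpy as np
from scipy.optimize import curve_fit

def Logisticfunc(t, q, t0=0):
    return 1 / (1 + np.exp(-q * (t - t0)))

# hypothetical noise-free data following a logistic trend (q = 10.2, t0 = 0.1)
t_data = np.linspace(-1.0, 1.0, 41)
y_data = Logisticfunc(t_data, 10.2, 0.1)

# fit q and t0; curve_fit passes the two parameters positionally after t
(q_hat, t0_hat), _ = curve_fit(Logisticfunc, t_data, y_data, p0=[8.0, 0.0])

# freeze the fitted values; the result is called as b(t), like before
b = partial(Logisticfunc, q=q_hat, t0=t0_hat)
```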

Of course, this way you lose some fidelity to your data, because interpolation is more flexible and follows the data more closely.

For now I just fitted it with a polynomial of a degree low enough that it doesn't produce very high peaks within my data range. I also tried a piecewise function, which didn't work, so I am unsure whether there is a better way to do it.
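When a fitted polynomial is evaluated inside the ODE, Horner's method keeps the evaluation to multiplication and addition, which a symbolic t should support (an assumption, not checked here). A sketch with hypothetical data and degree 2; np.polyfit returns the coefficients highest power first.

```python
import numpy as np

# hypothetical smooth input data
time_data = np.linspace(0.0, 10.0, 21)
input1 = 2.0 + 0.3 * time_data - 0.02 * time_data**2

coeffs = np.polyfit(time_data, input1, deg=2)  # highest power first
coeffs_list = coeffs.tolist()                  # plain Python floats

def poly_at(t):
    # Horner's method: ((c0 * t + c1) * t + c2), only * and +
    out = 0.0
    for c in coeffs_list:
        out = out * t + c
    return out
```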

If you’re willing to try a different ODE solver, there’s Diffrax which is an ODE solver for JAX. I haven’t used it myself, but it looks like it has interpolation included. There are ways to get it to work in PyMC, but it might end up being a fair amount of work.