sunode / PyTensor "rewrite failure" error

Hi all. This may be more of a sunode question, so apologies if it's in the wrong place.

I'm trying to fit an ODE model with PyMC using sunode. I got what I need working with scipy's odeint, but it is very slow, so I'm trying sunode instead.

For now I'm just running the example from the sunode docs (the "Quickstart with PyMC" page), pasted into a Jupyter notebook, and I get what looks like an error inside PyTensor. The full error log is attached here:
pytensor_compilation_error__37uleym.txt (21.7 KB)
but the key lines seem to be:

> ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
> ERROR (pytensor.graph.rewriting.basic): node: Cast{float32}(50)

The end of the log complains that it cannot find unistd.h. I'm using PyMC 5.13.0 and sunode 0.5.0 (the default sunode from conda, 0.2.2, failed because sunode.wrappers.as_pytensor wasn't defined). It looks like a problem with how PyTensor compiles its C code (?).
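One way to check whether this is a PyTensor compilation problem rather than a sunode one is a minimal sketch (my own, not from the sunode or PyTensor docs) that compiles a tiny graph containing the same kind of constant Cast{float32} node, with no sunode involved. The function name pytensor_c_backend_works is made up for this post, and I'm assuming that a broken compiler setup would make this fail with a similar unistd.h error.

```python
def pytensor_c_backend_works():
    """Compile a tiny PyTensor graph containing a constant cast to float32.

    Returns None if PyTensor is not installed, and True if the compiled
    function returns the expected value. If the C toolchain is broken,
    pytensor.function is expected (not guaranteed) to raise a compilation error.
    """
    try:
        import pytensor
        import pytensor.tensor as pt
    except ImportError:
        return None

    x = pt.dscalar("x")
    # Same kind of node as in the log: the constant 50 cast to float32,
    # which the constant_folding rewrite evaluates at compile time.
    fifty = pt.cast(pt.constant(50), "float32")
    f = pytensor.function([x], x * fifty)
    return float(f(1.0)) == 50.0


if __name__ == "__main__":
    print(pytensor_c_backend_works())
```

If this tiny function fails too, the problem is presumably the compiler toolchain and not sunode itself.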

The full notebook cell I'm running is below. I suspect the problem is something in my installation, but I'm stumped about what to try next. Any advice would be appreciated, thanks!
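To make the installation details easier to compare, here is a small sketch (again my own, with a made-up helper name, environment_report) that collects the versions of the relevant packages, whether g++ or clang++ is on PATH, and PyTensor's configured compiler (pytensor.config.cxx) and default float type (pytensor.config.floatX). Checking only g++ and clang++ is an assumption on my part; which compiler PyTensor needs depends on the platform.

```python
import importlib.metadata as md
import platform
import shutil


def environment_report(packages=("pymc", "pytensor", "sunode")):
    """Collect package versions and C++ compiler availability.

    Packages that are not installed are reported as None instead of raising.
    """
    report = {"python": platform.python_version(), "platform": platform.platform()}
    for name in packages:
        try:
            report[name] = md.version(name)
        except md.PackageNotFoundError:
            report[name] = None
    # PyTensor builds C extensions at runtime, so check whether common
    # compilers are on PATH (None means not found).
    for compiler in ("g++", "clang++"):
        report[compiler] = shutil.which(compiler)
    try:
        import pytensor
    except ImportError:
        report["pytensor.config.cxx"] = None
        report["pytensor.config.floatX"] = None
    else:
        # cxx is the compiler PyTensor will call ("" means none is configured);
        # floatX is the default float dtype, worth checking given the
        # Cast{float32} node in the log.
        report["pytensor.config.cxx"] = pytensor.config.cxx
        report["pytensor.config.floatX"] = pytensor.config.floatX
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```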

Full code that fails (just the pasted sunode example):

```python
import numpy as np
import sunode
import sunode.wrappers.as_pytensor
import pymc as pm

times = np.arange(1900,1921,1)
lynx_data = np.array([
    4.0, 6.1, 9.8, 35.2, 59.4, 41.7, 19.0, 13.0, 8.3, 9.1, 7.4,
    8.0, 12.3, 19.5, 45.7, 51.1, 29.7, 15.8, 9.7, 10.1, 8.6
])
hare_data = np.array([
    30.0, 47.2, 70.2, 77.4, 36.3, 20.6, 18.1, 21.4, 22.0, 25.4,
    27.1, 40.3, 57.0, 76.6, 52.3, 19.5, 11.2, 7.6, 14.6, 16.2, 24.7
])

def lotka_volterra(t, y, p):
    """Right hand side of Lotka-Volterra equation.

    All inputs are dataclasses of sympy variables, or in the case
    of non-scalar variables numpy arrays of sympy variables.
    """
    return {
        'hares': p.alpha * y.hares - p.beta * y.lynxes * y.hares,
        'lynxes': p.delta * y.hares * y.lynxes - p.gamma * y.lynxes,
    }

with pm.Model() as model:
    hares_start = pm.HalfNormal('hares_start', sigma=50)
    lynx_start = pm.HalfNormal('lynx_start', sigma=50)

    ratio = pm.Beta('ratio', alpha=0.5, beta=0.5)

    fixed_hares = pm.HalfNormal('fixed_hares', sigma=50)
    fixed_lynx = pm.Deterministic('fixed_lynx', ratio * fixed_hares)

    period = pm.Gamma('period', mu=10, sigma=1)
    freq = pm.Deterministic('freq', 2 * np.pi / period)

    log_speed_ratio = pm.Normal('log_speed_ratio', mu=0, sigma=0.1)
    speed_ratio = np.exp(log_speed_ratio)

    # Compute the parameters of the ode based on our prior parameters
    alpha = pm.Deterministic('alpha', freq * speed_ratio * ratio)
    beta = pm.Deterministic('beta', freq * speed_ratio / fixed_hares)
    gamma = pm.Deterministic('gamma', freq / speed_ratio / ratio)
    delta = pm.Deterministic('delta', freq / speed_ratio / fixed_hares / ratio)

with model:
    y0 = {
        # The initial number of hares is the random variable `hares_start`,
        # and it has shape (), so it is a scalar value.
        'hares': (hares_start, ()),
        'lynxes': (lynx_start, ()),
    }

    params = {
        'alpha': (alpha, ()),
        'beta': (beta, ()),
        'gamma': (gamma, ()),
        'delta': (delta, ()),
        # Parameters (or initial states) do not have to be random variables,
        # they can also be fixed numpy values. In this case the shape
        # is inferred automatically. Sunode will not compute derivatives
        # with respect to fixed parameters or initial states.
        'unused_extra': np.zeros(5),
    }

with model:
    from sunode.wrappers.as_pytensor import solve_ivp
    solution, *_ = solve_ivp(
        y0=y0,
        params=params,
        rhs=lotka_volterra,
        # The time points where we want to access the solution
        tvals=times,
        t0=times[0],
    )

with model:
    # We can access the individual variables of the solution using the
    # variable names.
    pm.Deterministic('hares_mu', solution['hares'])
    pm.Deterministic('lynxes_mu', solution['lynxes'])

    sd = pm.HalfNormal('sd')
    pm.LogNormal('hares', mu=solution['hares'], sigma=sd, observed=hare_data)
    pm.LogNormal('lynxes', mu=solution['lynxes'], sigma=sd, observed=lynx_data)

    trace = pm.sample()
```