I read the source code and the README and did some trial and error.
As a result, I solved it by monkey patching sunode.wrappers.as_pytensor.solve_ivp. The only functional change from the upstream function is a lib.CVodeSetMaxNumSteps(sol._ode, mxstep) call (plus a print) right after each solver is constructed, which raises CVODES's limit on the number of internal steps.
I have also seen that this dramatically reduces error terminations during the integration (an over-90% reduction). I think this is because the probability density is highly skewed with respect to the parameter space.
For other beginners who are having trouble, I'll paste my code below. It's messy (especially the imports), but I hope it helps someone.
FYI, @aseyboldt.
(But please let me know if there seems to be a problem.)
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import pandas as pd
from pytensor import function
from pytensor.graph.fg import MissingInputError
from pytensor.graph.op import Op
from pytensor.graph.basic import Constant, Variable
from pytensor.gradient import grad_not_implemented
from typing import Dict, Optional, Any, Callable
import sympy as sym
import sunode
from sunode import basic, symode, solver
from sunode.basic import CPointer, ERRORS, lib, ffi, check, check_ptr, Borrows, check_code
import sunode.wrappers.as_pytensor
from sunode.wrappers.as_pytensor import EvalRhs, SolveODE, SolveODEAdjoint, SolveODEAdjointBackward, solve_ivp
from sunode.solver import SolverError
from sunode.dtypesubset import as_flattened
import sunode.symode.problem
from sunode.symode.problem import SympyProblem
### Monkey patch for sunode.wrappers.as_pytensor.solve_ivp to increase mxstep
def new_solve_ivp(
    t0: float,
    y0: np.ndarray,
    params: Dict[str, Any],
    tvals: np.ndarray,
    rhs: Callable[[sym.Symbol, np.ndarray, np.ndarray], Dict[str, Any]],
    derivatives: str = 'adjoint',
    coords: Optional[Dict[str, pd.Index]] = None,
    make_solver=None,
    derivative_subset=None,
    solver_kwargs=None,
    simplify=None,
) -> Any:
    dtype = basic.data_dtype
    mxstep = 500000  # set mxstep (the original value is 5000)
    if solver_kwargs is None:
        solver_kwargs = {}
    if derivatives == "forward":
        params = params.copy()
        params["__initial_values"] = y0

    def read_dict(vals, name=None):
        if isinstance(vals, dict):
            return {name: read_dict(item, name) for name, item in vals.items()}
        else:
            if isinstance(vals, tuple):
                tensor, dim_names = vals
            else:
                tensor, dim_names = vals, pt.as_tensor_variable(vals, dtype="float64").type.shape
                if any(d is None for d in dim_names):
                    raise ValueError(
                        'Shapes of tensors need to be statically known or given explicitly.'
                    )
            if isinstance(dim_names, (str, int)):
                dim_names = (dim_names,)
            tensor = pt.as_tensor_variable(tensor, dtype="float64")
            if tensor.ndim != len(dim_names):
                raise ValueError(
                    f"Dimension mismatch for {name}: Value has rank {tensor.ndim}, "
                    f"but {len(dim_names)} was specified."
                )
            assert np.dtype(tensor.dtype) == dtype, tensor
            tensor_dtype = np.dtype(tensor.dtype)
            if tensor_dtype != dtype:
                raise ValueError(
                    f"Dtype mismatch for {name}: Got {tensor_dtype} but expected {dtype}."
                )
            return dim_names
    y0_dims = read_dict(y0)
    params_dims = read_dict(params)

    if derivative_subset is None:
        derivative_subset = []
        for path, val in as_flattened(params).items():
            if isinstance(val, tuple):
                tensor, _ = val
            else:
                tensor = val
            if isinstance(tensor, Variable):
                if not isinstance(tensor, Constant):
                    derivative_subset.append(path)

    problem = symode.problem.SympyProblem(
        params_dims, y0_dims, rhs, derivative_subset, coords=coords, simplify=simplify)

    # Flatten the parameters we differentiate with respect to
    flat_tensors = as_flattened(params)
    vars = []
    for path in problem.params_subset.subset_paths:
        tensor = flat_tensors[path]
        if isinstance(tensor, tuple):
            tensor, _ = tensor
        vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
    if vars:
        params_subs_flat = pt.concatenate(vars)
    else:
        params_subs_flat = pt.as_tensor_variable(np.zeros(0), dtype="float64")

    # Flatten the remaining (fixed) parameters
    vars = []
    for path in problem.params_subset.remainder.subset_paths:
        tensor = flat_tensors[path]
        if isinstance(tensor, tuple):
            tensor, _ = tensor
        vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
    if vars:
        params_remaining_flat = pt.concatenate(vars)
    else:
        params_remaining_flat = pt.as_tensor_variable(np.zeros(0), dtype="float64")

    # Flatten the initial state
    flat_tensors = as_flattened(y0)
    vars = []
    for path in problem.state_subset.paths:
        tensor = flat_tensors[path]
        if isinstance(tensor, tuple):
            tensor, _ = tensor
        vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
    y0_flat = pt.concatenate(vars)

    t0 = pt.as_tensor_variable(t0, dtype="float64")
    tvals = pt.as_tensor_variable(tvals, dtype="float64")
    if derivatives == 'adjoint':
        sol = solver.AdjointSolver(problem, **solver_kwargs)
        lib.CVodeSetMaxNumSteps(sol._ode, mxstep)  # the actual patch: raise the step limit
        print("patched!")
        wrapper = SolveODEAdjoint(sol)
        flat_solution = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
        solution = problem.flat_solution_as_dict(flat_solution)
        return solution, flat_solution, problem, sol, y0_flat, params_subs_flat
    elif derivatives == 'forward':
        if "sens_mode" not in solver_kwargs:
            raise ValueError("When `derivatives=True`, the `solver_kwargs` must contain one of `sens_mode={\"simultaneous\" | \"staggered\"}`.")
        sol = solver.Solver(problem, **solver_kwargs)
        lib.CVodeSetMaxNumSteps(sol._ode, mxstep)  # the actual patch: raise the step limit
        print("patched!")
        wrapper = SolveODE(sol)
        flat_solution, flat_sens = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
        solution = problem.flat_solution_as_dict(flat_solution)
        return solution, flat_solution, problem, sol, y0_flat, params_subs_flat, flat_sens, wrapper
    elif derivatives in [None, False]:
        sol = solver.Solver(problem, sens_mode=False)
        lib.CVodeSetMaxNumSteps(sol._ode, mxstep)  # the actual patch: raise the step limit
        print("patched!")
        assert False  # this branch is unfinished in the upstream function as well
sunode.wrappers.as_pytensor.solve_ivp = new_solve_ivp
### patch end
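For reference, here is a minimal sketch of how the patched function can be called inside a PyMC model, following the pattern from the sunode README. The decay ODE, priors, and observed data below are made up purely for illustration, not taken from my actual model:

import numpy as np
import pymc as pm
import sunode.wrappers.as_pytensor

times = np.linspace(0.0, 10.0, 21)

def decay_rhs(t, y, p):
    # dy/dt = -k * y
    return {'y': -p.k * y.y}

with pm.Model():
    k = pm.HalfNormal('k', sigma=1.0)
    # The patched solve_ivp prints "patched!" when it builds the solver,
    # so you can confirm the monkey patch is active.
    y_hat, _, problem, sol, *_ = sunode.wrappers.as_pytensor.solve_ivp(
        y0={'y': (np.array(1.0), ())},
        params={
            'k': (k, ()),
            # fixed dummy parameter, as in the sunode README examples
            'extra': np.zeros(1),
        },
        rhs=decay_rhs,
        tvals=times,
        t0=times[0],
    )
    # Fabricated "observations" just to make the sketch self-contained
    pm.Normal('obs', mu=y_hat['y'], sigma=0.1, observed=np.exp(-0.5 * times))
    # idata = pm.sample()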