Okay, you’ve convinced me. For now I’ll go back to a more conventional approach: treat the transition rates as constant and move the stochasticity into the obs_cov and state_cov terms.
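To make sure we’re talking about the same setup, the model I have in mind is the standard linear Gaussian state-space form, with the transition matrix held constant and the noise entering only through the covariance terms:

$$
\begin{aligned}
x_{t+1} &= T x_t + c_t + \eta_t, \qquad \eta_t \sim N(0, Q),\\
y_t &= Z x_t + \varepsilon_t, \qquad \varepsilon_t \sim N(0, H),
\end{aligned}
$$

where (as I understand the naming) Q is state_cov, H is obs_cov, and c_t is the state_intercept term I’d like to make time-varying.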
Unfortunately, I’m still getting hung up on the basic syntax. Here’s another example, based on the forward Euler discretization of a forced spring with damping (x is position, v is velocity). The time-varying component here is the forcing term, which should show up in state_intercept. There are six versions below (VERSION 0 through 5), trying various combinations of the syntax you provided; I can only get the constant-forcing versions to work.
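Concretely, a forward Euler step of size dt applied to x'' + 2ax' + w^2 x = f(t) gives

$$
\begin{bmatrix} x_{t+1} \\ v_{t+1} \end{bmatrix}
=
\begin{bmatrix} 1 & \Delta t \\ -\Delta t\,\omega^2 & 1 - 2a\,\Delta t \end{bmatrix}
\begin{bmatrix} x_t \\ v_t \end{bmatrix}
+
\begin{bmatrix} 0 \\ \Delta t\, f(t) \end{bmatrix},
$$

so the transition matrix is constant and the forcing enters only through the second component of state_intercept. That’s exactly how the matrices are filled in below.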
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import arviz as az
from pymc_extras.statespace.core.statespace import PyMCStateSpace
import pytensor.tensor as pt
import pymc as pm
# Damped Spring model
# x'' + 2ax' + (w^2)x = f
##############################
# Set up true model and data #
##############################
# choose versions 0-5
VERSION = 0
# true model parameters
a_true = 0.3
w_true = 0.7
sigma_true = 0.03
# initial conditions
x_init = 2
v_init = 0.5
# simulation parameters
T_init = 0
T_fin = 10
dt = 0.1
N = int(np.ceil( (T_fin - T_init)/dt ))
time = np.arange(T_init, T_fin, dt)
def forcing_fun(t):
    if VERSION in [0, 1, 2]:
        return 1
    if VERSION in [3, 4, 5]:
        return np.sin(t)
def exact_model(y, t):
    x = y[0]
    v = y[1]
    # the model equations
    dxdt = v
    dvdt = -2*a_true*v - w_true**2*x + forcing_fun(t)
    return [dxdt, dvdt]
# solve ODE
y = odeint(exact_model, [x_init, v_init], time)
x_true = y[:,0] + np.random.normal(0, sigma_true, size=N)
v_true = y[:,1] + np.random.normal(0, sigma_true, size=N)
# plot results
plt.plot(time, x_true)
plt.plot(time, v_true)
plt.xlabel('time')
plt.show()
###############################
# Set up the stochastic model #
###############################
class SpringModel(PyMCStateSpace):
    def __init__(self, dt=1, n_timesteps=100):
        # defines the "shape" of the matrices & vectors
        k_states = 2   # size of the state vector x
        k_posdef = 0   # number of shocks (size of the state covariance matrix Q)
        k_endog = 1    # number of observed states
        self.dt = dt
        self.n_timesteps = n_timesteps
        super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)

    def make_symbolic_graph(self):
        # sets the matrix & vector values
        x0 = self.make_and_register_variable("x0", shape=(2,))
        ar_params = self.make_and_register_variable("ar_params", shape=(2,))  # [w, a]
        sigma_x = self.make_and_register_variable("sigma_x", shape=())
        self.ssm["design", 0, 0] = 1
        self.ssm["obs_cov", 0, 0] = sigma_x**2
        self.ssm["initial_state", :] = x0
        self.ssm["transition", :, :] = np.eye(2)
        self.ssm["transition", 0, 1] += self.dt
        self.ssm["transition", 1, 0] -= self.dt*ar_params[0]**2
        self.ssm["transition", 1, 1] -= self.dt*2*ar_params[1]

        ##############
        # This works #
        ##############
        # version 0: constant forcing
        if VERSION == 0:
            self.ssm["state_intercept", 1] = self.dt
        # version 1: constant forcing, with time index included
        if VERSION == 1:
            self.ssm["state_intercept", :, 1] = self.dt

        ################
        # This doesn't #
        ################
        # version 2: explicitly declaring the size of the time dimension
        if VERSION == 2:
            S_I = pt.zeros((self.n_timesteps, self.k_states))
            S_I = S_I[:, 1].set(self.dt)
            self.ssm["state_intercept"] = S_I
        # versions 3 and 4 introduce a new variable to be used when the forcing term isn't constant
        if VERSION == 3:
            forcing = self.make_and_register_variable("forcing", shape=(self.n_timesteps,))
            self.ssm["state_intercept", :, 1] = self.dt*forcing
        if VERSION == 4:
            forcing = self.make_and_register_variable("forcing", shape=(self.n_timesteps,))
            S_I = pt.zeros((self.n_timesteps, self.k_states))
            S_I = S_I[:, 1].set(self.dt*forcing)
            self.ssm["state_intercept"] = S_I
        # version 5: skipping the deterministic variable as an intermediary
        if VERSION == 5:
            self.ssm["state_intercept", :, 1] = self.dt*pt.as_tensor_variable(forcing_fun(time))

    # All parameter names created with "make_and_register_variable" must
    # be registered in a class property called param_names.
    @property
    def param_names(self):
        if VERSION in [0, 1, 2, 5]:
            return ["x0", "ar_params", "sigma_x"]
        if VERSION in [3, 4]:
            return ["x0", "ar_params", "sigma_x", "forcing"]
#################
# Fit the model #
#################
sm = SpringModel(dt=dt, n_timesteps=N)
with pm.Model() as pymc_mod:
    x0 = pm.Normal("x0", mu=3, sigma=0.25, shape=(2,))
    ar_params = pm.Gamma("ar_params", alpha=2, beta=0.5, shape=(2,))
    sigma_x = pm.Exponential("sigma_x", 1)
    if VERSION in [3, 4]:
        forcing = pm.Deterministic("forcing", pt.as_tensor_variable(forcing_fun(time)))
    sm.build_statespace_graph(data=x_true[:, None], mode="JAX")  # data as an (N, 1) column
    idata = pm.sample(
        nuts_sampler="nutpie",
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
    )
az.plot_trace(idata, var_names=["x0", "ar_params", "sigma_x"])
plt.show()
az.plot_posterior(
    idata, var_names=["x0", "ar_params", "sigma_x"],
    ref_val=[x_init, v_init, w_true, a_true, sigma_true],
)
plt.show()
Versions 2 and 4 first throw a warning:
UserWarning: Skipping CheckAndRaise Op (assertion: The first dimension of a time varying matrix (the time dimension) must be equal to the first dimension of the data (the time dimension).) as JAX tracing would remove it.
then hang for a long time before finally crashing with:
RuntimeError: All initialization points failed
Caused by:
Logp function returned error: Python error: IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int8)>with, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
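In case it’s diagnostic, that IndexError looks like the generic JAX complaint about slicing with a traced (non-static) value inside a jitted function. A minimal sketch that produces the same message (head and n are just illustrative names, unrelated to the statespace code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def head(x, n):
    # n is traced under jit, so the slice bound isn't static
    return x[:n]

head(jnp.arange(10), 5)
# IndexError: Array slice indices must have static start/stop/step ...
```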
Versions 3 and 5 seem the most promising, in that the error looks like it’s just a matter of my having specified a shape or formatted a matrix incorrectly somewhere. Unfortunately, I haven’t been able to locate the problem:
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: SpecifyShape([ 0. … .45753589], 1)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
(lots of text)
AssertionError: SpecifyShape: dim 0 of input has shape 100, expected 1.
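For what it’s worth, the final assertion is easy to reproduce in isolation (a minimal sketch, not my actual graph), which makes me suspect my length-100 forcing vector is ending up somewhere that expects a single time step:

```python
import numpy as np
import pytensor.tensor as pt

v = pt.vector("v")
out = pt.specify_shape(v, (1,))  # declare shape (1,), then feed length 100
out.eval({v: np.zeros(100)})
# AssertionError: SpecifyShape: dim 0 of input has shape 100, expected 1.
```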
Any advice on how to fix this would be appreciated.