Does the state_space model support time-dependent transition matrices?

The following model runs for me without issue:

import numpy as np
import pandas as pd
import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc_extras.statespace.models.utilities import make_default_coords
from pymc_extras.statespace.core import PyMCStateSpace

class SpringModel(PyMCStateSpace):
    """Two-state linear model with a time-varying state intercept.

    The hidden state is ``[x1, x2]``; only ``x1`` is observed, with Gaussian
    noise of standard deviation ``sigma_x``. There are no state shocks
    (``k_posdef=0``), so all time variation enters through a
    ``(n_timesteps, k_states)`` state intercept that is non-zero only in the
    ``x2`` column.

    The transition matrix is ``I + dt * [[0, 1], [-w**2, -2*a]]`` with
    ``ar_params = [w, a]``. Together with the intercept on ``x2``, this reads as
    an explicit-Euler step of ``x1'' = -w**2 x1 - 2 a x1' + f`` (presumably a
    damped spring, as the class name suggests).

    Parameters
    ----------
    dt : float
        Step size multiplying the dynamics and the intercept.
    n_timesteps : int
        Length of the time dimension of the state intercept (and of the
        ``forcing`` parameter). Presumably this must equal the number of
        observations later passed to ``build_statespace_graph`` — TODO confirm.
    estimate_force : bool
        If True, register a ``forcing`` parameter of shape ``(n_timesteps,)``
        and use ``dt * forcing`` as the ``x2`` intercept. Otherwise the ``x2``
        intercept is the constant ``dt`` at every step.
    """

    def __init__(self, dt=1, n_timesteps=100, estimate_force=False):
        # These attributes are set before the base initialiser runs. Presumably
        # the base class builds the graph (reading self.dt etc.) during
        # __init__ — keep this ordering.
        self.dt = dt
        self.n_timesteps = n_timesteps
        self.estimate_force = estimate_force

        super().__init__(
            k_endog=1,   # one observed series (x1)
            k_states=2,  # state vector [x1, x2]
            k_posdef=0,  # no stochastic state shocks
        )

    def make_symbolic_graph(self):
        """Register the model parameters and fill in the state-space matrices."""
        ssm = self.ssm
        step = self.dt

        # Initial state mean and covariance are both free parameters.
        init_mean = self.make_and_register_variable("x0", shape=(2,))
        init_cov = self.make_and_register_variable('P0', shape=(2, 2))
        ssm['initial_state_cov'] = init_cov

        # ar_params = [w, a]; sigma_x is the observation-noise scale.
        ar_params = self.make_and_register_variable("ar_params", shape=(2,))
        sigma_x = self.make_and_register_variable("sigma_x", shape=())

        # Observe x1 directly, with variance sigma_x**2.
        ssm["design", 0, 0] = 1
        ssm["obs_cov", 0, 0] = sigma_x**2
        ssm["initial_state", :] = init_mean

        # Start from the identity, then add dt * [[0, 1], [-w**2, -2a]].
        ssm["transition", :, :] = np.eye(2)
        ssm["transition", 0, 1] += step
        ssm["transition", 1, 0] -= step*ar_params[0]**2
        ssm["transition", 1, 1] -= step*2*ar_params[1]

        # Time-varying intercept with an explicit time dimension: zero
        # everywhere except the x2 column.
        if self.estimate_force:
            forcing = self.make_and_register_variable("forcing", shape=(self.n_timesteps,))
            x2_push = step*forcing
        else:
            x2_push = step
        intercept = pt.zeros((self.n_timesteps, self.k_states))
        ssm["state_intercept"] = intercept[:, 1].set(x2_push)

    @property
    def param_names(self):
        """Names of the registered parameters; ``forcing`` only if estimated."""
        optional = ['forcing'] if self.estimate_force else []
        return ["x0", "P0", "ar_params", "sigma_x", *optional]

    @property
    def state_names(self):
        """Names of the two hidden states."""
        return ['x1', 'x2']

    @property
    def shock_names(self):
        """No shocks: the model is declared with ``k_posdef=0``."""
        return []

    @property
    def observed_states(self):
        """Only ``x1`` is observed."""
        return ['x1']

    @property
    def coords(self):
        """Default coordinates derived from the model by the library helper."""
        return make_default_coords(self)



dt = 0.1
N = 100
# Every observation is NaN (presumably treated as missing by the Kalman
# filter), so this run is driven by the priors alone.
data = pd.DataFrame(np.nan, index=pd.RangeIndex(N), columns=['x1'])
sm = SpringModel(dt=dt, n_timesteps=N, estimate_force=True)

with pm.Model() as pymc_mod:
    x0 = pm.Normal("x0", mu=3, sigma=0.25, shape=(2,))
    P0_diag = pm.Gamma('P0_diag', alpha=2, beta=1, shape=(2,))
    # Registered under the model name "P0" so the statespace model finds it.
    P0 = pm.Deterministic('P0', pt.diag(P0_diag))

    ar_params = pm.Gamma("ar_params", alpha=2, beta=0.5, shape=(2,))
    sigma_x = pm.Exponential("sigma_x", 1)

    # Random-walk prior on the forcing term. The statespace parameters are
    # matched by *model name* (as with "P0" above), so the cumulative sum
    # itself must be registered as "forcing".
    # Previously the iid Normal steps held that name and the .cumsum() was
    # never used, so the random-walk prior was silently dropped.
    forcing_steps = pm.Normal('forcing_steps', 0, 0.1, shape=(N,))
    forcing = pm.Deterministic('forcing', forcing_steps.cumsum())

    sm.build_statespace_graph(data=data, mode="JAX")
    idata = pm.sample(
        nuts_sampler="nutpie",
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"}
    )

Your first error is a sampling error. It is most likely because you did not define P0, so the Kalman filter's estimated state covariance was initialized as the zero matrix.

The second error is a numba error, and it has already been reported upstream. You can fix it by downgrading to numba==0.60.0, or by waiting a couple of days until we raise PyMC's pinned PyTensor version (the fix is already in PyTensor).

1 Like