State Space Models in PyMC

The PyMCStateSpace class (the base class for any state space model) now looks like this:

class PyMCStateSpace:
    def __init__(self, data, k_states, k_posdef):
        self.data = data
        self.n_obs, self.k_endog = data.shape
        self.k_states = k_states
        self.k_posdef = k_posdef

        # All models contain a state space representation and a Kalman filter
        self.ssm = AesaraRepresentation(data, k_states, k_posdef)
        self.kalman_filter = KalmanFilter()

        # Placeholders for the aesara functions that will return the predicted state, covariance, and log likelihood
        # given parameter vector theta

        self.log_likelihood = None
        self.filtered_states = None
        self.filtered_covariances = None

    def unpack_statespace(self):
        a0 = self.ssm['initial_state']
        P0 = self.ssm['initial_state_cov']
        Q = self.ssm['state_cov']
        H = self.ssm['obs_cov']
        T = self.ssm['transition']
        R = self.ssm['selection']
        Z = self.ssm['design']

        return a0, P0, Q, H, T, R, Z

    def build_statespace_graph(self, theta: at.TensorVariable) -> None:
        self._clear_existing_graphs()
        self.update(theta)
        states, covariances, log_likelihood = self.kalman_filter.build_graph(self.data, *self.unpack_statespace())

        self.log_likelihood = log_likelihood
        self.filtered_states = states
        self.filtered_covariances = covariances

The build_statespace_graph method takes in theta (a flat vector of the random variables generated by the PyMC model) and constructs the state space side of the graph: the machinery for updating the matrices plus the Kalman filter equations. No more aesara.function or at.grad calls.
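The heavy lifting happens in a subclass's update(theta), which has to slice theta back apart and write the pieces into self.ssm. As a plain-NumPy sketch of just that slicing step (shapes taken from the two-state Nile model below; the helper name is mine, not part of the actual class):

```python
import numpy as np

def unpack_theta(theta, k_states=2, k_endog=1):
    """Split the flat parameter vector back into the pieces that were
    concatenated in the model block: x0, P0, obs_sigma, state_sigmas.
    A NumPy stand-in for what a subclass's update() method has to do
    before writing into self.ssm."""
    cursor = 0
    x0 = theta[cursor:cursor + k_states]                  # -> 'initial_state'
    cursor += k_states
    P0 = theta[cursor:cursor + k_states ** 2].reshape(k_states, k_states)  # -> 'initial_state_cov'
    cursor += k_states ** 2
    obs_sigma = theta[cursor:cursor + k_endog]            # -> diagonal of 'obs_cov'
    cursor += k_endog
    state_sigmas = theta[cursor:cursor + k_states]        # -> diagonal of 'state_cov'
    return x0, P0, obs_sigma, state_sigmas

# With the Nile model's shapes there are 2 + 4 + 1 + 2 = 9 parameters
theta = np.arange(9, dtype=float)
x0, P0, obs_sigma, state_sigmas = unpack_theta(theta)
```

The point is that the concatenation order in the model block and the slicing order in update() have to agree exactly, which is the fragility the name-to-index mapping idea further down is meant to remove.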

I also added a helper method, _clear_existing_graphs(), that checks whether any of log_likelihood, filtered_states, or filtered_covariances is not None, and deletes them if so. I ran into issues when changing priors in the PyMC model and then re-running the code without first resetting the kernel: the state space side of the graph would still look for the old random variables. This is supposed to fix that. I don't know if there is a helper function in aesara to flush existing graphs or otherwise do this more elegantly.
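Roughly, the reset logic is just this (a minimal stand-in class so the sketch is self-contained; the real method lives on PyMCStateSpace):

```python
class GraphHolder:
    """Minimal stand-in for PyMCStateSpace, just to show the reset logic."""

    def __init__(self):
        self.log_likelihood = None
        self.filtered_states = None
        self.filtered_covariances = None

    def _clear_existing_graphs(self):
        # Drop references to graphs built on a previous run, so the next
        # call to build_statespace_graph doesn't see stale random variables.
        for attr in ('log_likelihood', 'filtered_states', 'filtered_covariances'):
            if getattr(self, attr) is not None:
                setattr(self, attr, None)
```

Once the old graph objects are unreferenced, rebuilding against the new priors starts from a clean slate.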

Here’s how it now looks in a model block:

with pm.Model(coords=coords) as nile_model:
    state_sigmas = pm.HalfNormal('state_sigma', sigma=1.0, dims=['states'])
    obs_sigma = pm.HalfNormal('obs_sigma', sigma=1.0, dims=['obs'])

    x0 = pm.Normal('initial_states', mu=0.0, sigma=1.0, dims=['states'])
    initial_sigma = pm.HalfNormal('initial_sigma', sigma=5.0, dims=['states'])
    
    P0 = np.eye(2) * initial_sigma
    
    theta = at.concatenate([x0.ravel(), P0.ravel(), obs_sigma.ravel(), state_sigmas.ravel()])
    model.build_statespace_graph(theta)  # model is the PyMCStateSpace instance built earlier

    likelihood = pm.Potential("likelihood", model.log_likelihood)
    y_hat = pm.Deterministic('y_hat', model.filtered_states)
    cov_hat = pm.Deterministic('cov_hat', model.filtered_covariances)
    
    prior_params = pm.sample_prior_predictive(var_names=['y_hat', 'cov_hat'])

Note that you have to call model.build_statespace_graph after setting up all the priors, and pass in a flat vector of random variables. Like I said, it would be nice if we could internally map names to indices of the state space matrices, and have a function that uses the names of the random variables to get everything to the right place, rather than requiring the user to flatten and concatenate everything.
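A sketch of what that could look like; the PARAM_ORDER registry and make_theta name are hypothetical, not part of the current code:

```python
import numpy as np

# Hypothetical: the statespace model declares its parameter layout once,
# as (name, flattened size) pairs, matching the Nile example above.
PARAM_ORDER = [
    ('initial_states', 2),   # x0
    ('initial_cov', 4),      # P0, flattened
    ('obs_sigma', 1),
    ('state_sigma', 2),
]

def make_theta(params: dict) -> np.ndarray:
    """Assemble the flat theta vector from named arrays, in the declared
    order, so the user never has to flatten and concatenate by hand."""
    pieces = []
    for name, size in PARAM_ORDER:
        value = np.ravel(params[name])
        if value.size != size:
            raise ValueError(f'{name}: expected {size} values, got {value.size}')
        pieces.append(value)
    return np.concatenate(pieces)
```

With something like this, the model block could pass a dict of named random variables and get shape errors with useful messages instead of silent misalignment.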

New code and an updated notebook are available in the repo I linked above. I'm working on an ARMA(p,q) model next, then will expand to ARIMA, then SARIMAX. I still need to iron out some shape wrinkles in the Kalman filter too.

EDIT:

I added an ARMA class to estimate ARMA(p,q) models, with an example notebook here. In principle you can do ARIMA with this as well, but you have to handle the differencing yourself. The initial states are still causing problems, and it will be difficult to fit higher-order models without implementing re-parameterizations that constrain the weights to be stationary (all eigenvalues of the companion matrix less than 1 in modulus).
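In the meantime, stationarity is at least cheap to check after sampling via the companion matrix, even though enforcing it during sampling needs a proper re-parameterization. A sketch:

```python
import numpy as np

def is_stationary(ar_coefs) -> bool:
    """Check that all eigenvalues of the AR companion matrix have modulus
    less than 1, i.e. that y_t = phi_1 y_{t-1} + ... + phi_p y_{t-p} + eps_t
    is stationary."""
    phi = np.asarray(ar_coefs, dtype=float)
    p = phi.size
    # Companion matrix: AR coefficients on the first row, identity below.
    companion = np.zeros((p, p))
    companion[0, :] = phi
    if p > 1:
        companion[1:, :-1] = np.eye(p - 1)
    return bool(np.all(np.abs(np.linalg.eigvals(companion)) < 1))

print(is_stationary([0.5, 0.3]))  # -> True  (stationary AR(2))
print(is_stationary([1.1]))       # -> False (explosive AR(1))
```

This is only a diagnostic; to actually constrain the sampler you'd need something like the partial-autocorrelation transform that statsmodels uses for its stationarity-enforcing parameterization.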

Nothing worth doing is easy, I guess.

6 Likes