`pymc-experimental` now includes state space models!


PyMC now has experimental support for linear, Gaussian state space time series models via the pymc_experimental.statespace module. You can see example usages here, here, and here, along with a tutorial on making custom state space models here.

Get it with pip install pymc-experimental in your favorite (conda-installed) pymc environment!

Linear Gaussian State Space Models

Over one year (and two major releases!) ago, I made my first forum post here on Discourse, asking about state space models. If you don’t know, these are a flexible family of models of the form:

\begin{align}
x_t &= c_t + T_t x_{t-1} + R_t \varepsilon_t & \varepsilon_t &\sim N(0, Q_t) \\
y_t &= d_t + Z_t x_t + \eta_t & \eta_t &\sim N(0, H_t) \\
x_0 &\sim N(\bar{x}, P_0)
\end{align}

Despite being linear and Gaussian, this form turns out to be very flexible, and it can represent a huge range of time series models, including SARIMAX, VARMAX, exponential smoothing, and a combinatorial number of additive structural models.
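To make the notation concrete, here is a minimal NumPy sketch that simulates from the equations above with time-invariant matrices (the t subscripts dropped). The function name `simulate_lgss` and the local-linear-trend example matrices are illustrative assumptions, not part of pymc_experimental.

```python
import numpy as np


def simulate_lgss(c, d, T, Z, R, H, Q, x_bar, P0, n_steps, rng):
    """Draw hidden states x_1..x_n and observations y_1..y_n from a
    time-invariant linear Gaussian state space model."""
    x = rng.multivariate_normal(x_bar, P0)  # x_0 ~ N(x_bar, P0)
    xs, ys = [], []
    for _ in range(n_steps):
        eps = rng.multivariate_normal(np.zeros(Q.shape[0]), Q)  # state shocks
        eta = rng.multivariate_normal(np.zeros(H.shape[0]), H)  # measurement noise
        x = c + T @ x + R @ eps  # transition equation
        y = d + Z @ x + eta      # observation equation
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)


# Example: a local linear trend (level + slope) observed with noise
T = np.array([[1.0, 1.0], [0.0, 1.0]])
Z = np.array([[1.0, 0.0]])
R = np.eye(2)
c, d = np.zeros(2), np.zeros(1)
Q = np.diag([0.1, 0.01])
H = np.array([[0.5]])
rng = np.random.default_rng(0)
states, obs = simulate_lgss(c, d, T, Z, R, H, Q, np.array([10.0, 0.5]), np.eye(2), 50, rng)
print(states.shape, obs.shape)  # (50, 2) (50, 1)
```

Swapping in different T, Z, and R matrices is all it takes to move between model families; the simulation loop itself never changes.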

These models are workhorses in a wide range of disciplines. They are, however, a bit cumbersome to represent in PyMC, requiring a lot of overhead with setting up matrices, writing scan Ops, and handling marginalization of hidden states.

As of today, that headache is over! You can easily work with Linear Gaussian State Space models directly in PyMC to perform parameter estimation, hidden state inference, missing data interpolation, forecasting, impulse response functions, and much more.

For all the details, see the notebooks linked in the tl;dr. To whet your appetite, though, let’s look at two representative examples of how a state space model is declared, estimated, and used.

API Introduction: ARIMA(1,1,1)

Instead of going through the API in detail, let’s dive straight into fitting a model. For “classical” time series models like SARIMA and VARMA, the statespace API tries to match that of statsmodels.tsa.statespace. That means you will need to instantiate a state space model before you start working on your PyMC model. The PyMCStateSpace model holds all the bits and bobs you need to implement the equations written above, and it also holds a KalmanFilter object, which is responsible for marginalizing over the hidden states of your model. This is one of the superpowers of the state space setup: the power to make inferences about unobserved time series, given observed ones (and a structural model, of course!).
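The marginalization the KalmanFilter performs can be illustrated outside PyMC. The sketch below is plain NumPy with hypothetical function names, not pymc_experimental's implementation. It computes the log-likelihood of the observations with the hidden states integrated out using the textbook Kalman filter recursions, then checks the result against the brute-force answer obtained by writing down the joint Gaussian distribution of all observations at once. It assumes time-invariant matrices and the timing of the equations above (the first observation is y_1).

```python
import numpy as np


def kalman_loglik(y, c, d, T, Z, R, H, Q, x_bar, P0):
    """log p(y_1..y_n) with the hidden states marginalized by a Kalman filter."""
    m, P = x_bar, P0
    RQR = R @ Q @ R.T
    loglik = 0.0
    for y_t in y:
        # Predict: propagate the state mean and covariance one step
        m = c + T @ m
        P = T @ P @ T.T + RQR
        # Update: condition on the new observation
        v = y_t - d - Z @ m          # innovation
        F = Z @ P @ Z.T + H          # innovation covariance
        F_inv = np.linalg.inv(F)
        _, logdet = np.linalg.slogdet(F)
        loglik += -0.5 * (len(y_t) * np.log(2 * np.pi) + logdet + v @ F_inv @ v)
        K = P @ Z.T @ F_inv          # Kalman gain
        m = m + K @ v
        P = P - K @ Z @ P
    return loglik


def brute_force_loglik(y, c, d, T, Z, R, H, Q, x_bar, P0):
    """Same quantity, from the joint Gaussian over all stacked observations."""
    n, k = y.shape
    m, S = x_bar, P0
    means, covs = [], []
    for _ in range(n):
        m = c + T @ m
        S = T @ S @ T.T + R @ Q @ R.T
        means.append(d + Z @ m)
        covs.append(S)
    big = np.zeros((n * k, n * k))
    for s in range(n):
        for t in range(s, n):
            # Cov(x_s, x_t) = Cov(x_s) (T^(t-s))^T for s <= t
            block = Z @ covs[s] @ np.linalg.matrix_power(T, t - s).T @ Z.T
            if s == t:
                block = block + H
            big[s * k:(s + 1) * k, t * k:(t + 1) * k] = block
            big[t * k:(t + 1) * k, s * k:(s + 1) * k] = block.T
    diff = y.reshape(-1) - np.concatenate(means)
    _, logdet = np.linalg.slogdet(big)
    return -0.5 * (n * k * np.log(2 * np.pi) + logdet + diff @ np.linalg.solve(big, diff))


rng = np.random.default_rng(42)
T = np.array([[0.8, 0.1], [0.0, 0.5]])
Z = np.array([[1.0, 1.0]])
R = np.eye(2)
Q = np.diag([0.3, 0.2])
H = np.array([[0.1]])
c, d = np.array([0.1, 0.0]), np.array([0.5])
x_bar, P0 = np.zeros(2), np.eye(2)
y = rng.normal(size=(6, 1))
print(np.isclose(kalman_loglik(y, c, d, T, Z, R, H, Q, x_bar, P0),
                 brute_force_loglik(y, c, d, T, Z, R, H, Q, x_bar, P0)))  # True
```

The brute-force version builds an (n·k)×(n·k) covariance, so its cost grows cubically with the series length; the filter does the same job with one small k×k inversion per time step.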

For this first example, we will use the daily closing price of Google stock since the beginning of 2022. The data look like this:

import yfinance as yf
goog_raw = yf.Ticker('GOOG').history(start='2022-01-21', end=None).Close
goog = goog_raw.resample('B').last()
goog.index = goog.index.tz_localize(None)

I re-sampled the data to business days, which is not technically correct, because it creates missing values for weekdays on which the market was closed (mostly Mondays). This is actually convenient, because it shows off the automatic interpolation features of the statespace module.

Price data is known to be non-stationary, while returns are stationary. Thus, we can model the log price in first differences. This is easily accomplished with an ARIMA(p, 1, q) model. For simplicity, I will make an ARIMA(1,1,1). If you don’t know, p is the number of autoregressive lags to include, q is the number of lagged innovations (moving average terms) to include, and the middle number is the number of times the data needs to be differenced to render it “stationary”. See here for a textbook treatment on the subject.
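To illustrate how an ARIMA(1,1,1) fits into the state space form, here is one standard 3-state layout in NumPy (three states, the same count as x0 in the message below). The state holds the previous level, the current first difference, and the scaled lagged shock. The ordering and scaling pymc_experimental uses internally may differ, so treat these matrices and the helper names as assumptions. The check confirms that running the state space recursion reproduces integrating an ARMA(1,1) by hand.

```python
import numpy as np


def arima111_matrices(phi, theta):
    """Transition, design, and selection matrices for ARIMA(1,1,1).

    State x_t = [y_{t-1}, z_t, theta * eps_t], where z_t = y_t - y_{t-1}
    follows an ARMA(1,1): z_t = phi * z_{t-1} + eps_t + theta * eps_{t-1}.
    """
    T = np.array([[1.0, 1.0, 0.0],   # y_t = y_{t-1} + z_t
                  [0.0, phi, 1.0],   # z_{t+1} = phi * z_t + theta * eps_t (+ eps_{t+1})
                  [0.0, 0.0, 0.0]])  # theta * eps_{t+1} comes only from the new shock
    Z = np.array([[1.0, 1.0, 0.0]])  # y_t = y_{t-1} + z_t
    R = np.array([[0.0], [1.0], [theta]])
    return T, Z, R


def simulate_state_space(T, Z, R, x0, shocks):
    x, ys = x0.copy(), []
    for eps in shocks:
        x = T @ x + R[:, 0] * eps
        ys.append(Z[0] @ x)
    return np.array(ys)


def simulate_arima_directly(phi, theta, y_start, shocks):
    y, z, eps_prev, ys = y_start, 0.0, 0.0, []
    for eps in shocks:
        z = phi * z + eps + theta * eps_prev  # ARMA(1,1) in first differences
        y = y + z                             # undo the differencing
        eps_prev = eps
        ys.append(y)
    return np.array(ys)


phi, theta, y_start = 0.3, -0.2, 5.0
shocks = np.random.default_rng(0).normal(scale=0.02, size=200)
T, Z, R = arima111_matrices(phi, theta)
x0 = np.array([y_start, 0.0, 0.0])  # no initial difference or lagged shock
ss_path = simulate_state_space(T, Z, R, x0, shocks)
print(np.allclose(ss_path, simulate_arima_directly(phi, theta, y_start, shocks)))  # True
```

Because the level is carried inside the state, the model can be fit to the undifferenced log price directly, which is why the undifferenced series is passed to build_statespace_graph later on.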

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pymc_experimental.statespace as pmss

ss_mod = pmss.BayesianSARIMA(order=(1, 1, 1), stationary_initialization=False)
>>> Out: The following parameters should be assigned priors inside a PyMC model block: 
	    x0 -- shape: (3,), constraints: None, dims: ('state',)
	    P0 -- shape: (3, 3), constraints: Positive Semi-definite, dims: ('state', 'state_aux')
	    ar_params -- shape: (1,), constraints: None, dims: ('ar_lag',)
	    ma_params -- shape: (1,), constraints: None, dims: ('ma_lag',)
	    sigma_state -- shape: (1,), constraints: Positive, dims: ('observed_state',)

After making the model, we get a message telling us what we need to do inside a PyMC block (this message can be disabled by passing verbose=False to any state space model). Let’s follow the instructions and assign priors to the 5 parameters requested: x0, P0, ar_params, ma_params, and sigma_state.

with pm.Model(coords=ss_mod.coords) as pymc_model:
    x0_observed = pm.Laplace('x0_obs', mu=5, b=0.1, shape=(1,))
    x0 = pm.Deterministic('x0', pt.concatenate([x0_observed, pt.zeros(2,)]), dims=['state'])

    P0_diag = pm.HalfNormal('P0_diag', sigma=0.01, dims=['state'])
    P0 = pm.Deterministic('P0', pt.diag(P0_diag), dims=['state', 'state_aux'])
    ar_params = pm.Laplace('ar_params', mu=0, b=0.25, dims=['ar_lag'])
    ma_params = pm.Laplace('ma_params', mu=0, b=0.25, dims=['ma_lag'])
    sigma_state = pm.Gamma('sigma_state', alpha=2, beta=1, dims=['observed_state'])

    ss_mod.build_statespace_graph(goog.apply(np.log), mode='JAX')

Some things to note:

  1. All statespace models carry coordinates you can pass to pm.Model in their coords property. This will let us set object shapes using the dims suggested in the message we got on model construction.
  2. We are free to choose any priors we want, or even to construct hierarchical dependencies between priors! The only thing that matters is that the name of the final product matches the name asked for by the construction message. I demonstrate this in the construction of the priors on P0 and x0, which are admittedly overly complex.
  3. It is known that log returns of stocks do not exhibit ARMA dynamics, so I put very skeptical priors on ar_params and ma_params. Let’s see your Gibbs sampler do that!

The ss_mod.build_statespace_graph method is the key bridge between the PyMC model and the state space model. After you’re done declaring priors, call it to automatically load up the PyMC model with all the state space objects and Kalman filter outputs. It requires two pieces of information:

  1. The data you want to filter. In this case, we want to use the log price, because the 1st difference in our model will transform it to log returns.
  2. The mode you want to use to fit the model. In this case, we plan to use a JAX sampler, so we need to tell the state space model to compile all the internals into JAX mode.

If you look at the PyMC model you can see all the objects that have been loaded in:

\begin{array}{rcl}
\text{x0\_obs} &\sim & \operatorname{Laplace}(5,~0.1)\\
\text{P0\_diag} &\sim & \operatorname{HalfNormal}(0,~0.01)\\
\text{ar\_params} &\sim & \operatorname{Laplace}(0,~0.25)\\
\text{ma\_params} &\sim & \operatorname{Laplace}(0,~0.25)\\
\text{sigma\_state} &\sim & \operatorname{Gamma}(2,~f())\\
\text{x0} &\sim & \operatorname{Deterministic}(f(\text{x0\_obs}))\\
\text{P0} &\sim & \operatorname{Deterministic}(f(\text{P0\_diag}))\\
\text{c} &\sim & \operatorname{Deterministic}(f())\\
\text{d} &\sim & \operatorname{Deterministic}(f())\\
\text{T} &\sim & \operatorname{Deterministic}(f(\text{ar\_params}))\\
\text{Z} &\sim & \operatorname{Deterministic}(f())\\
\text{R} &\sim & \operatorname{Deterministic}(f(\text{ma\_params}))\\
\text{H} &\sim & \operatorname{Deterministic}(f())\\
\text{Q} &\sim & \operatorname{Deterministic}(f(\text{sigma\_state}))\\
\text{filtered\_state} &\sim & \operatorname{Deterministic}(f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}))\\
\text{predicted\_state} &\sim & \operatorname{Deterministic}(f(\text{x0\_obs},~\text{ma\_params},~\text{sigma\_state},~\text{ar\_params},~\text{P0\_diag}))\\
\text{filtered\_covariance} &\sim & \operatorname{Deterministic}(f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}))\\
\text{predicted\_covariance} &\sim & \operatorname{Deterministic}(f(\text{P0\_diag},~\text{ma\_params},~\text{sigma\_state},~\text{ar\_params},~\text{x0\_obs}))\\
\text{smoothed\_state} &\sim & \operatorname{Deterministic}(f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}))\\
\text{smoothed\_covariance} &\sim & \operatorname{Deterministic}(f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}))\\
\text{obs} &\sim & \operatorname{SequenceMvNormal}(f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}),~f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}),~f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}),~f(\text{sigma\_state},~\text{ma\_params},~\text{ar\_params},~\text{P0\_diag},~\text{x0\_obs}))
\end{array}

To fit the model, call pm.sample as normal!

with pymc_model:
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.95)

From there, you can look at your favorite diagnostic plots as normal. But beyond that, there is some special functionality available to you.

Conditional Posterior Sampling

You might be interested in the hidden states of the model, given the posteriors over the parameters. In our case, this lets us obtain estimates for log returns on days when the market was closed. You can get these as follows:

post = ss_mod.sample_conditional_posterior(idata)

This returns an xarray Dataset with 6 data variables:

Dimensions:                       (chain: 4, draw: 1000, time: 412, state: 3,
                                   observed_state: 1)
Coordinates:
  * chain                         (chain) int64 0 1 2 3
  * draw                          (draw) int64 0 1 2 3 4 ... 995 996 997 998 999
  * time                          (time) datetime64[ns] 2022-01-21 ... 2023-0...
  * state                         (state) <U12 'data' 'data_star' 'state_star_1'
  * observed_state                (observed_state) <U4 'data'
Data variables:
    filtered_posterior            (chain, draw, time, state) float64 4.744 .....
    filtered_posterior_observed   (chain, draw, time, observed_state) float64 ...
    predicted_posterior           (chain, draw, time, state) float64 4.758 .....
    predicted_posterior_observed  (chain, draw, time, observed_state) float64 ...
    smoothed_posterior            (chain, draw, time, state) float64 4.846 .....
    smoothed_posterior_observed   (chain, draw, time, observed_state) float64 ...
Attributes:
    created_at:                 2023-08-21T16:47:40.567049
    arviz_version:              0.16.1
    inference_library:          pymc
    inference_library_version:  5.7.2

For each Kalman filter output, you get all the hidden states and the observed states. In this model, since there is no measurement error, the Kalman filter and smoother reproduce the observed data exactly wherever it is available, so they’re not so interesting there. We can, however, use them to look at the missing values. The filter gives the best guess of each missing value using only information up to that point, while the smoother also integrates future information to improve the guess.
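The difference between filtering and smoothing across a gap can be seen in a small NumPy sketch: a Kalman filter that skips the update step whenever an observation is missing (NaN), followed by a Rauch-Tung-Striebel (RTS) smoother. This is a textbook version with hypothetical function names, written for illustration; it is not pymc_experimental's implementation. The toy model is a random walk observed with a little noise (unlike the ARIMA above, which has none), with no intercepts.

```python
import numpy as np


def kalman_filter(y, T, Z, R, H, Q, x0, P0):
    """Filtered and one-step predicted moments; rows of y containing NaN are skipped."""
    n, k_states = y.shape[0], x0.shape[0]
    RQR = R @ Q @ R.T
    filt_m, pred_m = np.zeros((n, k_states)), np.zeros((n, k_states))
    filt_P, pred_P = np.zeros((n, k_states, k_states)), np.zeros((n, k_states, k_states))
    m, P = x0, P0
    for t in range(n):
        m, P = T @ m, T @ P @ T.T + RQR       # predict
        pred_m[t], pred_P[t] = m, P
        if not np.isnan(y[t]).any():          # update only when y_t is observed
            F = Z @ P @ Z.T + H
            K = P @ Z.T @ np.linalg.inv(F)
            m = m + K @ (y[t] - Z @ m)
            P = P - K @ Z @ P
        filt_m[t], filt_P[t] = m, P
    return filt_m, filt_P, pred_m, pred_P


def rts_smoother(T, filt_m, filt_P, pred_m, pred_P):
    """Backward pass that folds information from later observations into each estimate."""
    sm_m, sm_P = filt_m.copy(), filt_P.copy()
    for t in range(filt_m.shape[0] - 2, -1, -1):
        J = filt_P[t] @ T.T @ np.linalg.inv(pred_P[t + 1])  # smoother gain
        sm_m[t] = filt_m[t] + J @ (sm_m[t + 1] - pred_m[t + 1])
        sm_P[t] = filt_P[t] + J @ (sm_P[t + 1] - pred_P[t + 1]) @ J.T
    return sm_m, sm_P


T = Z = R = np.eye(1)
Q, H = np.array([[0.1]]), np.array([[0.05]])
y = np.array([[0.0], [0.2], [0.1], [np.nan], [0.6], [0.7]])
filt_m, filt_P, pred_m, pred_P = kalman_filter(y, T, Z, R, H, Q, np.zeros(1), np.eye(1))
sm_m, sm_P = rts_smoother(T, filt_m, filt_P, pred_m, pred_P)
# At the gap the filter just carries its last estimate forward; the smoother
# also uses y[4] and y[5] and is therefore less uncertain there.
print(filt_m[3, 0] == filt_m[2, 0], sm_P[3, 0, 0] < filt_P[3, 0, 0])  # True True
```

At the final time step the two agree, since there is no future information left for the smoother to use.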

Here are the Kalman smoothed interpolations for the missing days, with 94% HDIs and the data. Note that I plot the 0th state (the log price) and exponentiate it to get back to the actual price.

import arviz as az
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# State 0 is the log price: take its smoothed posterior and 94% HDI
data = post.smoothed_posterior.isel(state=0)
hdi = az.hdi(data).smoothed_posterior
ax.plot(goog.index, goog.values, label='Data', ls='--', color='tab:red')
ax.plot(goog.index, np.exp(data.mean(dim=['chain', 'draw']).values), label='Smoothed Mean')
ax.fill_between(goog.index, *np.exp(hdi.values).T, alpha=0.8, color='tab:orange', label='HDI 94%')
ax.legend()


Another common post-estimation task is forecasting. This is easy in the state space framework: we just roll those two matrix equations forward from the last posterior hidden states. Here’s a 100-day forecast for the Google stock price:

forecast = ss_mod.forecast(idata, start=goog.index[-1], periods=100)

fig, ax = plt.subplots()
ax.plot(goog.index, goog.values)

data = forecast.forecast_observed.isel(observed_state=0).stack(sample=['chain', 'draw']).values
ax.plot(forecast.coords['time'], np.exp(data), color='0.5', alpha=0.1);

As you can see, the best stock price forecast is ¯\_(ツ)_/¯. You either believe Eugene Fama, or you believe we need a better model. Maybe a bit of both. But I hope you can see that forecasting is a breeze in the statespace framework!
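Rolling the equations forward amounts to repeating the prediction step of the Kalman filter with no new data. Here is a minimal NumPy sketch of the forecast mean and covariance for a single parameter draw (the function name is hypothetical). It uses an ARIMA(0,1,0), i.e. a random walk in the log price, which is the ARIMA(1,1,1) above with its AR and MA coefficients at zero; the starting price of 100 and the 2% daily volatility are made-up illustration values. The mean stays at the last state while the variance grows linearly with the horizon, which is why the fan of forecast paths widens.

```python
import numpy as np


def forecast_moments(x_last, P_last, c, d, T, Z, R, Q, H, periods):
    """Mean and covariance of y_{n+1}..y_{n+h} given x_n ~ N(x_last, P_last)."""
    m, P = x_last, P_last
    RQR = R @ Q @ R.T
    means, covs = [], []
    for _ in range(periods):
        m = c + T @ m            # transition equation with shocks at their mean of zero
        P = T @ P @ T.T + RQR    # uncertainty accumulates every step
        means.append(d + Z @ m)
        covs.append(Z @ P @ Z.T + H)
    return np.array(means), np.array(covs)


# Random walk in the log price: T = Z = R = 1, no measurement error
sigma2 = 0.02 ** 2
means, covs = forecast_moments(
    x_last=np.array([np.log(100.0)]), P_last=np.zeros((1, 1)),
    c=np.zeros(1), d=np.zeros(1), T=np.eye(1), Z=np.eye(1), R=np.eye(1),
    Q=np.array([[sigma2]]), H=np.zeros((1, 1)), periods=100,
)
print(np.allclose(means[:, 0], np.log(100.0)))                 # True: flat forecast
print(np.allclose(covs[:, 0, 0], sigma2 * np.arange(1, 101)))  # True: linear variance growth
```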

Another Example: Structural Modeling

In addition to statsmodels.tsa.statespace-style models, pymc_experimental.statespace also includes a suite of “building blocks” to construct custom state space models. This is all explained in detail in the structural modeling example notebook. To get you interested, though, let’s look at an example. Here is the well-known airpass.csv dataset:

Unlike in the SARIMA/VARMA modeling framework, we do not require stationarity to do structural modeling. Here, we have at least 4 sources of non-stationarity:

  1. There is a non-zero mean
  2. There is a deterministic trend
  3. There is a seasonal pattern
  4. The intensity of the seasonal pattern is increasing over time

We can capture all of these features of the time series using structural components. These are found in the pymc_experimental.statespace.structural module. We will use three components:

  1. A non-stationary trend component, structural.LevelTrendComponent
  2. A frequency seasonal component, structural.FrequencySeasonality
  3. A serially-correlated error term, structural.AutoregressiveComponent

import pymc_experimental.statespace.structural as st

# Make the components. 
# Order=2 means we want our time series to have a level (position) and a trend (velocity). See the docs for details.
ll = st.LevelTrendComponent(order=2)

# For the frequency component, we need to specify the length of a season. 
# We have monthly data with an annual pattern, so it's 12 here.
se = st.FrequencySeasonality(season_length=12, name="annual")

# We can also add autocorrelated errors to the model using an AutoRegressive block
ar = st.AutoregressiveComponent(order=1)

# Components are added together like Lego blocks to incrementally build your model
mod = ll + ar + se

# When you're done adding, call the `.build()` to convert the components to a StateSpace model! 
ss_mod = mod.build()
>>>Out: The following parameters should be assigned priors inside a PyMC model block: 
	    initial_trend -- shape: (2,), constraints: None, dims: ('trend_state',)
	    sigma_trend -- shape: (2,), constraints: Positive, dims: ('trend_shock',)
	    ar_params -- shape: (1,), constraints: None, dims: (ar_lags, )
	    sigma_ar -- shape: (1,), constraints: Positive, dims: None
	    annual -- shape: (11,), constraints: None, dims: (annual_initial_state, )
	    sigma_annual -- shape: (1,), constraints: Positive, dims: None
	    P0 -- shape: (15, 15), constraints: Positive semi-definite, dims: ('state', 'state_aux')

As before, we get a message telling us what priors and dims to make. Now things are a bit more complex, because all three components contribute parameters and dims. Still, we just follow the instructions to make a PyMC model, then call ss_mod.build_statespace_graph!

with pm.Model(coords=ss_mod.coords) as model_2:
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=1, dims=['state'])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=['state', 'state_aux'])
    initial_trend = pm.Normal("initial_trend", sigma=[100, 1], dims=['trend_state'])
    annual_seasonal = pm.Normal("annual", sigma=100, dims=['annual_initial_state'])
    ar_params = pm.Laplace('ar_params', mu=0, b=0.2, dims=['ar_lags'])

    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=['trend_shock'])
    sigma_annual = pm.Gamma("sigma_annual", alpha=2, beta=1)
    sigma_ar = pm.HalfNormal("sigma_ar", sigma=0.5)

    # airpass holds the monthly passenger data plotted above
    ss_mod.build_statespace_graph(airpass, mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro")
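Conceptually, adding components (`ll + ar + se` above) stacks their state vectors: the transition and selection matrices are combined block-diagonally and the design matrices side by side, so the observation is the sum of the components’ contributions. The sketch below builds textbook versions of the three blocks in NumPy; the helper names, the parameterization, and the state ordering are assumptions, not pymc_experimental’s internals. Keeping a cosine and a sine state for each of the 6 annual harmonics gives 2 + 1 + 12 = 15 states, matching the shape of P0 in the message above. That message lists 11 initial values for `annual` rather than 12; how the library handles the remaining seasonal state is not something this sketch tries to reproduce.

```python
import numpy as np


def block_diag(*blocks):
    """Place 2-D arrays along the diagonal of a larger zero matrix."""
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((rows, cols))
    i = j = 0
    for b in blocks:
        out[i:i + b.shape[0], j:j + b.shape[1]] = b
        i, j = i + b.shape[0], j + b.shape[1]
    return out


def level_trend():
    # level_t = level_{t-1} + trend_{t-1}; trend_t = trend_{t-1}; one shock per state
    return np.array([[1.0, 1.0], [0.0, 1.0]]), np.array([[1.0, 0.0]]), np.eye(2)


def ar1(rho):
    return np.array([[rho]]), np.array([[1.0]]), np.eye(1)


def frequency_seasonal(season_length, n_harmonics):
    T_blocks, Z_parts = [], []
    for j in range(1, n_harmonics + 1):
        lam = 2 * np.pi * j / season_length
        # Each harmonic is a 2-D rotation by lam radians per period
        T_blocks.append(np.array([[np.cos(lam), np.sin(lam)],
                                  [-np.sin(lam), np.cos(lam)]]))
        Z_parts.append(np.array([[1.0, 0.0]]))  # only the cosine state is observed
    T = block_diag(*T_blocks)
    return T, np.hstack(Z_parts), np.eye(T.shape[0])


def add_components(*components):
    Ts, Zs, Rs = zip(*components)
    return block_diag(*Ts), np.hstack(Zs), block_diag(*Rs)


T_season, _, _ = frequency_seasonal(12, 6)
T, Z, R = add_components(level_trend(), ar1(0.5), frequency_seasonal(12, 6))
print(T.shape, Z.shape, R.shape)  # (15, 15) (1, 15) (15, 15)
# Every harmonic completes a whole number of cycles in 12 months
print(np.allclose(np.linalg.matrix_power(T_season, 12), np.eye(12)))  # True
```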

Posterior Components

Unlike in the ARIMA example, we might be interested in the hidden states of the structural model. For example, we might want to just isolate the trend component of the model. For this, StateSpace models have a helper method, extract_components_from_idata. First, get the conditional posterior hidden states:

post = ss_mod.sample_conditional_posterior(idata)

Next, give these hidden states to ss_mod.extract_components_from_idata:

component_idata = ss_mod.extract_components_from_idata(post)

# Check the names of what we got back
component_states = component_idata.coords["state"].values.tolist()
>>> Out: ['LevelTrend[level]', 'LevelTrend[trend]', 'AutoRegressive', 'annual']

We took the 15 hidden states of the model and boiled them down to just 4. Now we can visualize the contributions of each to the model:

fig, ax = plt.subplots(4, 1, figsize=(14, 9))
x_values = component_idata.coords["time"]
for axis, name in zip(fig.axes, component_states):
    data = component_idata.predicted_posterior.sel(state=name)
    hdi = az.hdi(data).predicted_posterior
    mean = data.mean(dim=["chain", "draw"])

    axis.plot(x_values, mean)
    axis.fill_between(x_values, *hdi.values.T, color="tab:blue", alpha=0.1)

Evidently, the model didn’t have much use for the autoregressive component! But we can see that the seasonal component was isolated, and the increasing variance is accounted for. There is a smooth linear increase in airline passengers over the sample period, with some brief trend reversals in 1954 and 1958.
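The idea behind collapsing many hidden states into a few components can be sketched as follows: each component’s contribution to the observation is its slice of the design matrix applied to its slice of the hidden states, and the contributions add back up to the full signal. The state layout, the slice boundaries, and the helper name are hypothetical. Note that the output of extract_components_from_idata above reports the level and trend states separately rather than as one weighted sum; this sketch does not attempt to replicate that detail.

```python
import numpy as np


def component_contributions(states, Z, slices):
    """Map each component name to Z[:, slice] @ x[slice] at every time step."""
    return {name: states[:, sl] @ Z[0, sl] for name, sl in slices.items()}


# Hypothetical layout mirroring the model above: 2 level/trend states, 1 AR state,
# and 12 seasonal states (cosine and sine for each of 6 harmonics)
Z = np.array([[1.0, 0.0, 1.0] + [1.0, 0.0] * 6])
slices = {"LevelTrend": slice(0, 2), "AutoRegressive": slice(2, 3), "annual": slice(3, 15)}
states = np.random.default_rng(0).normal(size=(144, 15))  # stand-in for posterior hidden states

parts = component_contributions(states, Z, slices)
print({name: part.shape for name, part in parts.items()})
# The component contributions add back up to the noiseless signal Z x_t
print(np.allclose(sum(parts.values()), states @ Z[0]))  # True
```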

Of course, we can also look at the posterior predicted observed states to check the accuracy of one-step-ahead forecasts over the whole sample period:

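fig, ax = plt.subplots()
post_stacked = post.stack(sample=["chain", "draw"])
x_values = post_stacked.coords["time"]
hdi_post = az.hdi(post)

# Plot every posterior draw of the one-step-ahead predictions as a faint line
ax.plot(x_values, post_stacked.predicted_posterior_observed.isel(observed_state=0).values,
        color='0.5', alpha=0.05, label='Posterior One-Step Ahead Predictions')
ax.fill_between(x_values, *hdi_post.predicted_posterior_observed.isel(observed_state=0).values.T, alpha=0.25,
                label='HDI 94%')
ax.plot(airpass.index, airpass.values, label='Data')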

And, of course, make forecasts:

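forecasts = ss_mod.forecast(idata, start=airpass.index[-1], periods=24)
forecasts = forecasts.stack(sample=["chain", "draw"])
forecast_draws = forecasts.forecast_observed.isel(observed_state=0)

fig, ax = plt.subplots()
ax.plot(airpass.index, airpass, label='Data')
# One faint line per posterior draw, plus the mean across draws
ax.plot(forecasts.coords['time'], forecast_draws.values, color='0.5', alpha=0.05,
        label='Posterior Forecasts')
ax.plot(forecasts.coords['time'], forecast_draws.mean(dim='sample').values, color='tab:red',
        label='Mean Posterior Forecast')

# Every draw adds a duplicate legend entry; keep one handle per unique label
handles, labels = ax.get_legend_handles_labels()
labels, ids = np.unique(labels, return_index=True)
handles = [handles[i] for i in ids]
ax.legend(handles, labels, loc='best')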

(These forecasts were made with a different model that excluded the autoregressive component)

Final Thoughts

There is a lot more you can do with the statespace package. SARIMAX and VARMAX models are supported, as are many more types of components in the structural module. There is, of course, still a lot of work to do. I started an issue tracker here if you are interested in getting involved or if you want to let us know what missing features are important to your use-case.

There are also limitations. Performance is an issue. I do not recommend running these models on the default NUTS sampler, because you will not have a good time. Sorry Windows gang. Numba support should come eventually for fast, cross-platform inference using nutpie. Regardless of sampler, for large state spaces or long time series, expect sampling to be quite slow. We will be working to improve this.

Overall, though, state space models are a great way to analyze your time-series data, and the pymc_experimental.statespace module offers tools to make model construction, inference, and post-estimation as easy as possible. They allow for the interpolation of missing data, even for multivariate outputs, as well as inference of unobserved time series. As long as you don’t need non-linear transition dynamics or non-Gaussian innovations, they offer a one-stop shop for all your time series needs!

If you do give it a try, be sure to let the team know what worked and didn’t work so we can keep iterating on the module. And, if you run into problems, don’t hesitate to ask here on Discourse. Happy modeling!


This is awesome!

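I am confused by this line: “As long as you don’t need non-linear transition dynamics or non-Gaussian innovations, they offer a one-stop shop for all your time series needs!”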

Is the ‘they’ in this sentence referring only to state-space models in this package? Because we can build a state-space model with either or both of non-linear transition dynamics and non-Gaussian innovations and still use pymc for inference, right?


Yeah, that sentence is meant to apply only to the state space models in this package, since they do hidden state inference with vanilla Kalman filtering. The usual PyMC machinery is still available/unchanged for any other cases.
