### tl;dr

PyMC now has experimental support for linear, Gaussian state space time series models via the `pymc_experimental.statespace` module. You can see example usages here, here, and here, along with a tutorial on making custom state space models here.

Get it with `pip install pymc-experimental` in your favorite (conda-installed) `pymc` environment!

# Linear Gaussian State Space Models

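Over one year (and two major releases!) ago, I made my first forum post here on Discourse, asking about state space models. If you don't know, these are a flexible family of models of the form:

$$
\begin{aligned}
x_t &= c_t + T_t x_{t-1} + R_t \varepsilon_t, & \varepsilon_t &\sim N(0, Q_t) \\
y_t &= d_t + Z_t x_t + \eta_t, & \eta_t &\sim N(0, H_t)
\end{aligned}
$$

where $x_t$ are the hidden states, $y_t$ are the observed data, and the initial state follows $x_0 \sim N(\bar{x}_0, P_0)$.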

Despite being linear and Gaussian, this form ends up very flexible, and admits representation of a huge range of time series models, including SARIMAX, VARMAX, exponential smoothing, and a combinatorial number of additive structural models.
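To see how a familiar model fits this mold, here is a minimal NumPy sketch (not code from `pymc_experimental`) that writes an ARMA(1,1) as a state space model. The particular `T`, `R`, `Z` matrices are one common textbook convention and an assumption on our part; the library's internal layout may differ. The final check confirms that rolling the transition equation forward reproduces the usual ARMA recursion.

```python
import numpy as np

# ARMA(1, 1) in state space form, with state x_t = [y_t, theta * eps_t].
# This is one textbook convention, assumed here for illustration.
phi, theta = 0.5, 0.3

T = np.array([[phi, 1.0],
              [0.0, 0.0]])   # transition matrix
R = np.array([[1.0],
              [theta]])      # selection matrix for the shock
Z = np.array([[1.0, 0.0]])   # design matrix: we observe the first state


def simulate_state_space(eps):
    """Roll x_t = T x_{t-1} + R eps_t forward and return y_t = Z x_t."""
    x = np.zeros(2)
    ys = []
    for e in eps:
        x = T @ x + R[:, 0] * e
        ys.append((Z @ x)[0])
    return np.array(ys)


def simulate_arma_directly(eps):
    """Textbook recursion y_t = phi * y_{t-1} + eps_t + theta * eps_{t-1}."""
    ys, y_prev, e_prev = [], 0.0, 0.0
    for e in eps:
        y = phi * y_prev + e + theta * e_prev
        ys.append(y)
        y_prev, e_prev = y, e
    return np.array(ys)


rng = np.random.default_rng(0)
eps = rng.normal(size=50)
print(np.allclose(simulate_state_space(eps), simulate_arma_directly(eps)))  # True
```

The same idea (pick `T`, `R`, `Z` so that the hidden state carries exactly the lags you need) is what lets SARIMAX, VARMAX and structural models all share one representation.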

These models are workhorses in a wide range of disciplines. They are, however, a bit cumbersome to represent in PyMC, requiring a lot of overhead with setting up matrices, writing `scan` `Op`s, and handling marginalization of hidden states.

As of today, that headache is over! You can easily work with Linear Gaussian State Space models directly in PyMC to perform parameter estimation, hidden state inference, missing data interpolation, forecasting, impulse response functions, and much more.

For all the details, see the notebooks linked in the tl;dr. To whet your appetite, though, let's look at two representative examples of how a state space model is declared, estimated, and used.

## API Introduction: ARIMA(1,1,1)

Instead of going through the API in detail, let's dive straight into fitting a model. For "classical" time series models like SARIMA and VARMA, the `statespace` API tries to match that of `statsmodels.tsa.statespace`. That means that you will need to instantiate a state space model before you start working on your PyMC model. The `PyMCStateSpace` model is going to hold all the bits and bobs you need to implement the equations written above, and it will also hold a `KalmanFilter` object. That's responsible for marginalizing over the hidden states of your model. This is one of the superpowers of the state space setup: the power to make inferences about unobserved time series, given observed ones (and a structural model, of course!).
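To make "marginalizing over the hidden states" concrete, here is a minimal NumPy sketch of a Kalman filter for the simplest state space model, a local level (random walk plus noise). It is an illustration under our own simplifying assumptions (scalar state, known variances), not the `KalmanFilter` class in `pymc_experimental`. The filter's log-likelihood matches the one you get by integrating the hidden states out analytically and writing the data as one big multivariate normal.

```python
import numpy as np


def local_level_loglike_kalman(y, a0, p0, q, h):
    """Log-likelihood of a local level model via the Kalman filter.

    Sketch model (assumed for illustration):
        x_0 ~ N(a0, p0),  x_t = x_{t-1} + w_t,  w_t ~ N(0, q)
        y_t = x_t + v_t,  v_t ~ N(0, h)
    where the first observation y[0] sees x_0.
    """
    a, p, loglike = a0, p0, 0.0
    for obs in y:
        # Update: condition the predicted state on the new observation
        v = obs - a            # one-step-ahead prediction error
        f = p + h              # variance of that error
        loglike += -0.5 * (np.log(2 * np.pi) + np.log(f) + v**2 / f)
        k = p / f              # Kalman gain
        a, p = a + k * v, p - k * p
        # Predict: push the state through the random-walk transition
        p = p + q
    return loglike


def local_level_loglike_direct(y, a0, p0, q, h):
    """Same likelihood, with the hidden states integrated out analytically."""
    n = len(y)
    t = np.arange(n)
    # Cov(y_s, y_t) = p0 + q * min(s, t) + h * [s == t]
    cov = p0 + q * np.minimum.outer(t, t) + h * np.eye(n)
    resid = np.asarray(y) - a0
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + resid @ np.linalg.solve(cov, resid))


rng = np.random.default_rng(42)
y = np.cumsum(rng.normal(scale=0.3, size=30)) + rng.normal(scale=0.1, size=30)
ll_kf = local_level_loglike_kalman(y, a0=0.0, p0=1.0, q=0.09, h=0.01)
ll_mvn = local_level_loglike_direct(y, a0=0.0, p0=1.0, q=0.09, h=0.01)
print(np.isclose(ll_kf, ll_mvn))  # True
```

The filter reaches the same number in O(n) steps without ever forming the n-by-n covariance, which is why it is the tool of choice for marginalizing out hidden states.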

For this first example, we will use the daily closing price of Google stock since the beginning of 2022. The data look like this:

```
import yfinance as yf
goog_raw = yf.Ticker('GOOG').history(start='2022-01-21', end=None).Close
goog = goog_raw.resample('B').last()
goog.index = goog.index.tz_localize(None)
goog.plot()
```

I resampled the data to business days, which is not technically correct, because it creates missing values for weekdays on which the market was closed (mostly Mondays). But this is nice, because it will show off the automatic interpolation features of the `statespace` module.
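The effect of `resample('B')` is easy to reproduce on a toy series. The dates and prices below are made up for illustration; the point is only that a weekday with no trading comes back as `NaN`, which the state space model then treats as a missing observation.

```python
import pandas as pd

# Hypothetical closing prices on business days, with Monday 2023-01-16
# (a market holiday) absent, as it would be in real exchange data.
days = pd.bdate_range("2023-01-09", "2023-01-20")
trading_days = days.drop(pd.Timestamp("2023-01-16"))
close = pd.Series(range(len(trading_days)), index=trading_days, dtype=float)

# Resampling to business-day frequency creates a row for every weekday,
# so the holiday reappears as a NaN.
resampled = close.resample("B").last()
print(resampled[resampled.isna()])
```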

Price data is known to be non-stationary, while returns are stationary. Thus, we can model the log price in first differences. This is easily accomplished with an ARIMA(p, 1, q) model. For simplicity, I will make an ARIMA(1,1,1). If you don't know, `p` is the number of autoregressive lags to include, `q` is the number of innovation (moving average) lags to include, and the middle number is the number of times the data needs to be differenced to render it "stationary". See here for a textbook treatment of the subject.

```
import pymc_experimental.statespace as pmss
ss_mod = pmss.BayesianSARIMA(order=(1, 1, 1), stationary_initialization=False)
```

```
>>> Out: The following parameters should be assigned priors inside a PyMC model block:
x0 -- shape: (3,), constraints: None, dims: ('state',)
P0 -- shape: (3, 3), constraints: Positive Semi-definite, dims: ('state', 'state_aux')
ar_params -- shape: (1,), constraints: None, dims: ('ar_lag',)
ma_params -- shape: (1,), constraints: None, dims: ('ma_lag',)
sigma_state -- shape: (1,), constraints: Positive, dims: ('observed_state',)
```

After making the model, we get a message telling us what we need to do inside a PyMC block (this message can be disabled by passing `verbose=False` to any state space model). Let's follow the instructions and assign priors to the 5 parameters requested: `x0`, `P0`, `ar_params`, `ma_params`, and `sigma_state`.

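```
import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model(coords=ss_mod.coords) as pymc_model:
    # Initial state: informative prior on the first state (the log price), zeros for the rest
    x0_observed = pm.Laplace('x0_obs', mu=5, b=0.1, shape=(1,))
    x0 = pm.Deterministic('x0', pt.concatenate([x0_observed, pt.zeros((2,))]), dims=['state'])
    # Initial covariance: a diagonal matrix built from positive variances
    P0_diag = pm.HalfNormal('P0_diag', sigma=0.01, dims=['state'])
    P0 = pm.Deterministic('P0', pt.diag(P0_diag), dims=['state', 'state_aux'])
    # Skeptical priors on the ARMA coefficients, centered at zero
    ar_params = pm.Laplace('ar_params', mu=0, b=0.25, dims=['ar_lag'])
    ma_params = pm.Laplace('ma_params', mu=0, b=0.25, dims=['ma_lag'])
    sigma_state = pm.Gamma('sigma_state', alpha=2, beta=1, dims=['observed_state'])
    # Feed the log price to the Kalman filter and compile the internals for JAX
    ss_mod.build_statespace_graph(goog.apply(np.log), mode='JAX')
```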

Some things to note:

- All statespace models carry coordinates you can pass to `pm.Model` in their `coords` property. This lets us set object shapes using the `dims` suggested in the message we got on model construction.
- We are free to choose any priors we want, or even to construct hierarchical dependencies between priors! The only thing that matters is that the name of the final product matches the name asked for by the construction message. I demonstrate this in the construction of the priors on `P0` and `x0`, which are admittedly overly complex.
- It is known that log returns of stocks do not exhibit ARMA dynamics, so I put very skeptical priors on `ar_params` and `ma_params`. Let's see your Gibbs sampler do that!

The `ss_mod.build_statespace_graph` method is the key bridge between the PyMC model and the state space model. After you're done declaring priors, call it to automatically load up the PyMC model with all the state space objects and Kalman filter outputs. It requires two pieces of information:

- The data you want to filter. In this case, we want to use the **log** price, because the first difference in our model will transform it to log returns.
- The mode you want to use to fit the model. In this case, we plan to use a JAX sampler, so we need to tell the state space model to compile all the internals into JAX mode.

If you look at the PyMC model you can see all the objects that have been loaded in:

```
pymc_model
```

To fit the model, call `pm.sample` as normal!

```
with pymc_model:
    # The graph was compiled in JAX mode, so sample with numpyro's NUTS
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.95)
```

From there, you can look at your favorite diagnostic plots as normal. But beyond that, there is some special functionality available to you.

### Conditional Posterior Sampling

You might be interested in the hidden states of the model, given the posteriors over the parameters. In our case, this lets us obtain estimates of log returns on days when the market was closed. You can get these as follows:

```
post = ss_mod.sample_conditional_posterior(idata)
```

This returns an `xarray` Dataset with 6 variables:

```
print(post)
<xarray.Dataset>
Dimensions: (chain: 4, draw: 1000, time: 412, state: 3,
observed_state: 1)
Coordinates:
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 ... 995 996 997 998 999
* time (time) datetime64[ns] 2022-01-21 ... 2023-0...
* state (state) <U12 'data' 'data_star' 'state_star_1'
* observed_state (observed_state) <U4 'data'
Data variables:
filtered_posterior (chain, draw, time, state) float64 4.744 .....
filtered_posterior_observed (chain, draw, time, observed_state) float64 ...
predicted_posterior (chain, draw, time, state) float64 4.758 .....
predicted_posterior_observed (chain, draw, time, observed_state) float64 ...
smoothed_posterior (chain, draw, time, state) float64 4.846 .....
smoothed_posterior_observed (chain, draw, time, observed_state) float64 ...
Attributes:
created_at: 2023-08-21T16:47:40.567049
arviz_version: 0.16.1
inference_library: pymc
inference_library_version: 5.7.2
```

For each Kalman filter output, you get all the hidden states and the observed states. In this model, since there is no measurement error, the Kalman filter and smoother noiselessly encode the observed data, so they're not so interesting. We can, however, use them to look at missing values. The filter gives the best guess of each missing value using only information up to that point, while the smoother integrates future information to improve the guess.
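The filter/smoother distinction is easiest to see on a toy problem. Below is a minimal NumPy sketch, under our own assumptions (a scalar local level model with known variances and small measurement noise), of a Kalman filter followed by a Rauch-Tung-Striebel smoother, with one missing observation. It illustrates the general technique, not the library's implementation: at the missing point the filter can only carry its prediction forward, while the smoother pulls the estimate toward the later data.

```python
import numpy as np


def filter_and_smooth_local_level(y, a0, p0, q, h):
    """Kalman filter + RTS smoother for x_t = x_{t-1} + w_t, y_t = x_t + v_t.

    Sketch assumptions: scalar state, known variances q (state) and h
    (measurement). NaN observations skip the update step.
    """
    n = len(y)
    a_pred, p_pred = np.empty(n), np.empty(n)  # before seeing y_t
    a_filt, p_filt = np.empty(n), np.empty(n)  # after seeing y_t
    a, p = a0, p0
    for t in range(n):
        a_pred[t], p_pred[t] = a, p
        if np.isnan(y[t]):
            # Missing data: nothing to condition on, so filtered = predicted
            a_filt[t], p_filt[t] = a, p
        else:
            k = p / (p + h)
            a_filt[t] = a + k * (y[t] - a)
            p_filt[t] = (1 - k) * p
        # Random-walk transition: mean carried forward, variance grows by q
        a, p = a_filt[t], p_filt[t] + q

    # Backward pass (RTS smoother); the transition "matrix" is 1 here
    a_smooth, p_smooth = a_filt.copy(), p_filt.copy()
    for t in range(n - 2, -1, -1):
        j = p_filt[t] / p_pred[t + 1]
        a_smooth[t] = a_filt[t] + j * (a_smooth[t + 1] - a_pred[t + 1])
        p_smooth[t] = p_filt[t] + j**2 * (p_smooth[t + 1] - p_pred[t + 1])
    return a_filt, p_filt, a_smooth, p_smooth


# Made-up data with the third observation missing
y = np.array([1.0, 1.2, np.nan, 1.9, 2.1])
a_filt, p_filt, a_smooth, p_smooth = filter_and_smooth_local_level(
    y, a0=1.0, p0=1.0, q=0.1, h=0.01
)
print(a_filt[2], a_smooth[2])  # smoothed guess sits between the neighbors
```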

Here are the Kalman-smoothed interpolations for the missing days, with 94% HDIs and the data. Note that I plot the 0th state (the log price) and exponentiate it to get back to the actual price.

```
import arviz as az
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Smoothed estimates of state 0 (the log price), including the missing days
data = post.smoothed_posterior.isel(state=0)
hdi = az.hdi(data).smoothed_posterior
ax.plot(goog.index, goog.values, label='Data', ls='--', color='tab:red')
# Exponentiate to go from log price back to price
ax.plot(goog.index, np.exp(data.mean(dim=['chain', 'draw']).values), label='Smoothed Mean')
ax.fill_between(goog.index, *np.exp(hdi.values).T, alpha=0.8, color='tab:orange', label='HDI 94%')
ax.legend()
plt.show()
```

### Forecasting

Another common post-estimation task is forecasting. This is easy in the state space framework: we just roll those two matrix equations forward from the last posterior hidden states. Here's a 100-day forecast for the Google stock price:

```
forecast = ss_mod.forecast(idata, start=goog.index[-1], periods=100)
fig, ax = plt.subplots()
ax.plot(goog.index, goog.values)
data = forecast.forecast_observed.isel(observed_state=0).stack(sample=['chain', 'draw']).values
ax.plot(forecast.coords['time'], np.exp(data), color='0.5', alpha=0.1);
```

As you can see, the best stock price forecast is `¯\_(ツ)_/¯`. You either believe Eugene Fama, or you believe we need a better model. Maybe a bit of both. But I hope you can see that forecasting is a breeze in the `statespace` framework!
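Under the hood, "rolling the equations forward" comes down to a simple recursion for the forecast mean and covariance. The sketch below is a minimal NumPy illustration under our own assumptions (no intercept terms, a known final state, made-up numbers); the library's `forecast` method instead works per posterior draw, as the `chain`/`draw` dimensions in the plotting code above suggest. As a toy case it uses a pure random walk for the log price, which is roughly where an ARIMA(1,1,1) lands when the AR and MA coefficients are close to zero.

```python
import numpy as np


def forecast_moments(T, RQR, x_last, P_last, periods):
    """Roll the transition equation forward to get forecast means and covariances.

    If x_{t+1} = T x_t + R eps_t with eps_t ~ N(0, Q), then
        E[x_{t+1}]   = T E[x_t]
        Cov[x_{t+1}] = T Cov[x_t] T' + R Q R'
    `RQR` is the precomputed state noise covariance R Q R'.
    """
    means, covs = [], []
    x, P = x_last, P_last
    for _ in range(periods):
        x = T @ x
        P = T @ P @ T.T + RQR
        means.append(x)
        covs.append(P)
    return np.array(means), np.array(covs)


# Random walk in the log price (i.i.d. log returns), with made-up numbers:
# 2% daily return volatility and a last log price known exactly.
q = 0.02**2
means, covs = forecast_moments(
    T=np.array([[1.0]]),
    RQR=np.array([[q]]),
    x_last=np.array([np.log(130.0)]),
    P_last=np.array([[0.0]]),
    periods=100,
)
# The mean forecast is flat, while the variance grows linearly with the horizon
print(np.allclose(means[:, 0], np.log(130.0)))            # True
print(np.allclose(covs[:, 0, 0], q * np.arange(1, 101)))  # True
```

A flat mean with an ever-widening band is exactly the "random walk" view of prices.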

## Another Example: Structural Modeling

In addition to `statsmodels.tsa.statespace`-style models, `pymc_experimental.statespace` also includes a suite of "building blocks" to construct custom state space models. This is all explained in detail in the structural modeling example notebook. To get you interested, though, let's look at an example. Here is the well-known `airpass.csv` dataset:

Unlike in the SARIMA/VARMA modeling framework, we do not require stationarity to do structural modeling. Here, we have at least 4 sources of non-stationarity:

- There is a non-zero mean
- There is a deterministic trend
- There is a seasonal pattern
- The intensity of the seasonal pattern is increasing over time

We can capture all of these features of the time series using structural components. These are found in the `pymc_experimental.statespace.structural` module. We will use three components:

- A non-stationary trend component, `structural.LevelTrendComponent`
- A frequency seasonal component, `structural.FrequencySeasonality`
- A serially correlated error term, `structural.AutoregressiveComponent`
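One way to picture what adding components does: each component brings its own transition block and design entries, and the combined model stacks them block-diagonally. The hand-written NumPy sketch below shows that idea for a level-plus-trend block and an AR(1) block (the seasonal block is left out for brevity). The matrix conventions are a common textbook choice and an assumption on our part, not a copy of the library's internals.

```python
import numpy as np


def block_diag(*blocks):
    """Place square matrices along the diagonal (like scipy.linalg.block_diag)."""
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out


# Level + trend (order=2): level_t = level_{t-1} + trend_{t-1}; trend_t = trend_{t-1}
T_trend = np.array([[1.0, 1.0],
                    [0.0, 1.0]])
Z_trend = np.array([1.0, 0.0])  # only the level loads on the data

# AR(1) error term: a single state multiplied by a (made-up) coefficient
rho = 0.4
T_ar = np.array([[rho]])
Z_ar = np.array([1.0])

# "Adding" components: block-diagonal transition, concatenated design vector
T = block_diag(T_trend, T_ar)
Z = np.concatenate([Z_trend, Z_ar])

# With all shocks switched off, the trend block traces out a straight line
x = np.array([10.0, 2.0, 0.0])  # level = 10, slope = 2, AR state = 0
path = []
for _ in range(5):
    path.append(float(Z @ x))
    x = T @ x
print(path)  # [10.0, 12.0, 14.0, 16.0, 18.0]
```

Because the blocks never interact in `T`, each component evolves on its own and only meets the others in the observation equation, which is what makes it possible to pull them apart again after fitting.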

```
import pymc_experimental.statespace.structural as st
# Make the components.
# Order=2 means we want our time series to have a level (position) and a trend (velocity). See the docs for details.
ll = st.LevelTrendComponent(order=2)
# For the frequency component, we need to specify the length of a season.
# We have monthly data with an annual pattern, so it's 12 here.
se = st.FrequencySeasonality(season_length=12, name="annual")
# We can also add autocorrelated errors to the model using an AutoRegressive block
ar = st.AutoregressiveComponent(order=1)
# Components are added together like Lego blocks to incrementally build your model
mod = ll + ar + se
# When you're done adding, call the `.build()` to convert the components to a StateSpace model!
ss_mod = mod.build()
```

```
>>>Out: The following parameters should be assigned priors inside a PyMC model block:
initial_trend -- shape: (2,), constraints: None, dims: ('trend_state',)
sigma_trend -- shape: (2,), constraints: Positive, dims: ('trend_shock',)
ar_params -- shape: (1,), constraints: None, dims: (ar_lags, )
sigma_ar -- shape: (1,), constraints: Positive, dims: None
annual -- shape: (11,), constraints: None, dims: (annual_initial_state, )
sigma_annual -- shape: (1,), constraints: Positive, dims: None
P0 -- shape: (15, 15), constraints: Positive semi-definite, dims: ('state', 'state_aux')
```

As before, we get a message telling us what priors and dims to make. Now things are a bit more complex, because all three components contribute parameters and dims. Still, we just follow the instructions to make a PyMC model, then call `ss_mod.build_statespace_graph`!

```
with pm.Model(coords=ss_mod.coords) as model_2:
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=1, dims=['state'])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=['state', 'state_aux'])
    # Wide prior on the initial level, tighter prior on the initial slope
    initial_trend = pm.Normal("initial_trend", sigma=[100, 1], dims=['trend_state'])
    annual_seasonal = pm.Normal("annual", sigma=100, dims=['annual_initial_state'])
    ar_params = pm.Laplace('ar_params', mu=0, b=0.2, dims=['ar_lags'])
    # Standard deviations of the shocks to each component
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=['trend_shock'])
    sigma_annual = pm.Gamma("sigma_annual", alpha=2, beta=1)
    sigma_ar = pm.HalfNormal("sigma_ar", sigma=0.5)
    ss_mod.build_statespace_graph(airpass, mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro")
```
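A note on the `annual` parameter requested above: with `season_length=12` it has 11 entries, not 12. A standard trigonometric (frequency-domain) seasonal explains why. Each harmonic j = 1, ..., 6 is a 2x2 rotation, except the last one (the Nyquist frequency), whose sine term is identically zero and so needs only one state: 5 * 2 + 1 = 11. The NumPy sketch below assumes that textbook form with all harmonics included; we have not checked it line by line against the library's `FrequencySeasonality`.

```python
import numpy as np


def frequency_seasonal_transition(season_length):
    """Transition matrix for a trigonometric seasonal component (textbook form).

    Harmonic j (j = 1, ..., season_length // 2) rotates its 2 states by
    2*pi*j / season_length each period. For even season lengths, the last
    harmonic rotates by pi: its sine state is always zero, so the block
    collapses to the 1x1 matrix [-1].
    """
    blocks = []
    for j in range(1, season_length // 2 + 1):
        lam = 2 * np.pi * j / season_length
        if season_length % 2 == 0 and j == season_length // 2:
            blocks.append(np.array([[-1.0]]))
        else:
            c, s = np.cos(lam), np.sin(lam)
            blocks.append(np.array([[c, s], [-s, c]]))
    n = sum(b.shape[0] for b in blocks)
    T = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        T[i:i + k, i:i + k] = b
        i += k
    return T


T = frequency_seasonal_transition(12)
print(T.shape)  # (11, 11)
# Every harmonic completes a whole number of cycles in 12 steps, so with the
# shocks switched off the seasonal pattern repeats exactly every 12 periods.
print(np.allclose(np.linalg.matrix_power(T, 12), np.eye(11)))  # True
```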

### Posterior Components

Unlike in the ARIMA example, we might be interested in the hidden states of the structural model. For example, we might want to isolate just the trend component of the model. For this, `StateSpace` models have a helper method, `extract_components_from_idata`. First, get the conditional posterior hidden states:

```
post = ss_mod.sample_conditional_posterior(idata)
```

Next, give these hidden states to `ss_mod.extract_components_from_idata`:

```
component_idata = ss_mod.extract_components_from_idata(post)
# Check the names of what we got back
component_states = component_idata.coords["state"].values.tolist()
component_states
>>> Out: ['LevelTrend[level]', 'LevelTrend[trend]', 'AutoRegressive', 'annual']
```

We took the 15 hidden states of the model and boiled them down to just 4. Now we can visualize the contributions of each to the model:

```
fig, ax = plt.subplots(4, 1, figsize=(14, 9))
x_values = component_idata.coords["time"]
# One panel per component: posterior mean with a 94% HDI band
for axis, name in zip(fig.axes, component_states):
    data = component_idata.predicted_posterior.sel(state=name)
    hdi = az.hdi(data).predicted_posterior
    mean = data.mean(dim=["chain", "draw"])
    axis.plot(x_values, mean)
    axis.fill_between(x_values, *hdi.values.T, color="tab:blue", alpha=0.1)
    axis.set_title(name)
plt.show()
```

Evidently, the model didn't have much use for the autoregressive component! But we can see that the seasonal component was isolated, and the increasing variance is accounted for. There is a smooth linear increase in airline passengers over the sample period, with some brief trend reversals in 1954 and 1958.
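For intuition about how many hidden states can be grouped into a handful of components, here is a hypothetical NumPy sketch. It assumes that each component owns a contiguous slice of the state vector and that its contribution to the data is the design vector restricted to that slice; this is our illustration of the general idea, not a description of how `extract_components_from_idata` is implemented. All names, sizes, and values are made up. Because the observation equation is linear, the contributions add back up to the full (noise-free) observation.

```python
import numpy as np

# Hypothetical layout: 2 trend states (level, slope), 1 AR state, and
# 3 seasonal states. Names and sizes are made up for illustration.
component_slices = {
    "trend": slice(0, 2),
    "ar": slice(2, 3),
    "seasonal": slice(3, 6),
}
# Made-up design vector: the level, the AR state and two seasonal states load on y
Z = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])

rng = np.random.default_rng(1)
states = rng.normal(size=(50, 6))  # fake hidden states, shape (time, state)

# Each component's contribution: its own states times its own slice of Z
contributions = {
    name: states[:, sl] @ Z[sl] for name, sl in component_slices.items()
}

# The contributions add back up to the noise-free observation equation y_t = Z x_t
total = sum(contributions.values())
print(np.allclose(total, states @ Z))  # True
```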

Of course, we can also look at the posterior predicted observed states to check the accuracy of one-step-ahead forecasts over the whole sample period:

```
fig, ax = plt.subplots()
post_stacked = post.stack(sample=["chain", "draw"])
x_values = post_stacked.coords["time"]
hdi_post = az.hdi(post)
ax.plot(
x_values,
post_stacked.predicted_posterior_observed.isel(observed_state=0).mean(dim="sample"),
label='Posterior One-Step Ahead Predictions'
)
ax.fill_between(x_values, *hdi_post.predicted_posterior_observed.isel(observed_state=0).values.T, alpha=0.25,
label='HDI 94%')
ax.plot(airpass.index, airpass.values, label='Data')
ax.legend()
plt.show()
```

And, of course, make forecasts:

```
forecasts = ss_mod.forecast(idata, start=airpass.index[-1], periods=24)
forecasts = forecasts.stack(sample=["chain", "draw"])
fig, ax = plt.subplots()
ax.plot(airpass.index, airpass, label='Data')
ax.plot(
forecasts.coords["time"],
forecasts.forecast_observed.values.squeeze(),
color="tab:orange",
alpha=0.1,
label='Posterior Forecasts'
)
ax.plot(
forecasts.coords["time"],
forecasts.forecast_observed.mean(dim="sample").values.squeeze(),
color="k",
alpha=1,
label = 'Mean Posterior Forecast'
)
handles, labels = ax.get_legend_handles_labels()
labels, ids = np.unique(labels, return_index=True)
handles = [handles[i] for i in ids]
ax.legend(handles, labels, loc='best')
plt.show()
```

(These forecasts were made with a different model that excluded the autoregressive component.)

## Final Thoughts

There is a lot more you can do with the statespace package. SARIMAX and VARMAX models are supported, as are many more types of components in the structural module. There is, of course, still a lot of work to do. I started an issue tracker here if you are interested in getting involved or if you want to let us know what missing features are important to your use-case.

There are also limitations. Performance is an issue. I do not recommend running these models on the default NUTS sampler, because you will not have a good time. Sorry Windows gang. Numba support should come eventually for fast, cross-platform inference using `nutpie`. Regardless of sampler, for large state spaces or long time series, expect sampling to be quite slow. We will be working to improve this.

Overall, though, state space models are a great way to analyze your time series data, and the `pymc_experimental.statespace` module offers tools to make model construction, inference, and post-estimation as easy as possible. They allow for the interpolation of missing data, even for multivariate outputs, as well as inference of unobserved time series. As long as you don't need non-linear transition dynamics or non-Gaussian innovations, they offer a one-stop shop for all your time series needs!

If you do give it a try, be sure to let the team know what worked and didn't work so we can keep iterating on the module. And, if you run into problems, don't hesitate to ask here on the Discourse. Happy modeling!