Declaring Priors using Loops for State Space Models

Hi Alex,

You’re 99% of the way there; you’re just missing a few of PyMC’s syntactic idiosyncrasies. Specifically:

  1. As the error says, you can’t create new PyMC random variables in the inner function of a scan. Instead, you have to make pytensor random variables, then tell PyMC to consider the entire sequence jointly as a single RV. To do this:

    • Use pm.Normal.dist instead of pm.Normal. This is how to get the underlying pytensor RV (see the short sketch after this list).
    • Use the collect_default_updates helper function to handle the seeding of the random number generator through the scan.
    • Use pm.CustomDist to tell PyMC your whole scan sequence is a joint RV.
  2. Then, since you’re using a CustomDist, you have to split out the scan into a helper function that will be used to create a joint distribution over the sequence of hidden state(s).
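
For a concrete feel of point 1, here is a small illustrative sketch (not part of the model below): pm.Normal('x', ...) registers a named variable in the surrounding model context, while pm.Normal.dist(...) just returns a pytensor random variable that you can manipulate and forward-sample directly.

import pymc as pm

# A raw pytensor random variable -- no model context required
x = pm.Normal.dist(mu=0, sigma=1)
print(pm.draw(x, draws=3))  # three forward draws from N(0, 1)

# By contrast, pm.Normal('x', mu=0, sigma=1) must be called inside a
# `with pm.Model():` block, where it registers 'x' as a model variable.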

Admittedly, it’s a lot of boilerplate. But together, it will look like this:

import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.pytensorf import collect_default_updates
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Generate artificial data
n_steps = 100
rng = np.random.default_rng(1234)
x0_true = rng.normal()
true_sigma = abs(rng.normal(scale=0.5))
true_irreg = abs(rng.normal(scale=0.5))
grw_innovations = rng.normal(scale=true_sigma, size=(n_steps,))
data = rng.normal(np.r_[x0_true, grw_innovations].cumsum()[1:], true_irreg)


# Helper function for pm.CustomDist
def statespace_dist(mu_init, sigma_level, size):

    def grw_step(mu_tm1, sigma_level):
        # pm.Normal.dist gives a raw pytensor innovation RV; collect_default_updates
        # gathers its RNG update so scan can propagate the random state between steps
        mu_t = mu_tm1 + pm.Normal.dist(sigma=sigma_level)
        return mu_t, collect_default_updates(outputs=[mu_t])

    mu, updates = pytensor.scan(fn=grw_step, 
                                outputs_info=[{"initial": mu_init}],
                                non_sequences=[sigma_level], 
                                n_steps=n_steps,
                                name='statespace',
                                strict=True)

    return mu


# PyMC Model
coords = {'time':np.arange(n_steps)}

with pm.Model(coords=coords) as model:           
    y_data = pm.MutableData('y_data', data, dims=['time'])

    mu_init = pm.Normal('mu_init', mu=0, sigma=1)
    sigma_level = pm.HalfNormal('sigma_level', sigma=1)
    sigma_irreg = pm.HalfNormal('sigma_irreg', sigma=1)

    mu = pm.CustomDist('hidden_states', 
                          mu_init,
                          sigma_level,
                          dist=statespace_dist,
                          dims=['time'])
    y_hat = pm.Normal('y_hat', mu, sigma_irreg, observed=y_data, dims=['time'])    
    idata = pm.sample()
    # needed for the posterior predictive plot below
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
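
As an optional sanity check (not strictly necessary), you can also forward-sample the CustomDist, e.g. before running pm.sample, to confirm that the scan-based distribution really generates random-walk-like trajectories:

with model:
    prior = pm.sample_prior_predictive()

# prior.prior['hidden_states'] has dims (chain, draw, time); each draw is one
# simulated Gaussian random walk of length n_steps
prior_mean = prior.prior['hidden_states'].mean(dim=['chain', 'draw'])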

Plot results:

fig, ax = plt.subplots(figsize=(14, 4))
x_grid = coords['time']

mu_data = idata.posterior.hidden_states
mu_hdi = az.hdi(mu_data).hidden_states
ax.plot(x_grid, mu_data.mean(dim=['chain', 'draw']), lw=2)
ax.fill_between(x_grid, *mu_hdi.values.T, alpha=0.5)


post_data = idata.posterior_predictive.y_hat
hdi = az.hdi(post_data).y_hat
# ax.plot(x_grid, post_data.mean(dim=['chain', 'draw']))
ax.fill_between(x_grid, *hdi.values.T, alpha=0.25, color='tab:blue')
ax.plot(data, color='k', ls='--', lw=0.5)

It’s good to work through all this to get a deep understanding of how PyMC works under the hood. But if you just want ready-to-go statespace models, you can check out the statespace module of pymc-experimental.

For more examples of writing scan-based time series models, though, you can check here and here.
