Hi Alex,
You’re 99% of the way there; you’re just missing some syntactic idiosyncrasies of PyMC. Specifically:
- As the error says, you can’t create new PyMC random variables in the inner function of a scan. Instead, you have to make PyTensor random variables, then tell PyMC to consider the entire sequence jointly as a single RV. To do this:
  - Use `pm.Normal.dist` instead of `pm.Normal`. This is how to get the underlying PyTensor RV.
  - Use the `collect_default_updates` helper function to handle the seeding of the random number generator through the scan.
  - Use `pm.CustomDist` to tell PyMC your whole scan sequence is a joint RV.
- Then, since you’re using a `CustomDist`, you have to split the `scan` out into a helper function that will be used to create a joint distribution over the sequence of hidden state(s).
Admittedly, it’s a lot of boilerplate. But put together, it looks like this:
```python
import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.pytensorf import collect_default_updates
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Generate artificial data
n_steps = 100
rng = np.random.default_rng(1234)
x0_true = rng.normal()
true_sigma = abs(rng.normal(scale=0.5))
true_irreg = abs(rng.normal(scale=0.5))
grw_innovations = rng.normal(scale=true_sigma, size=(n_steps,))
data = rng.normal(np.r_[x0_true, grw_innovations].cumsum()[1:], true_irreg)

# Helper function for pm.CustomDist (uses the global n_steps)
def statespace_dist(mu_init, sigma_level, size):
    def grw_step(mu_tm1, sigma_level):
        # One random-walk step built from a bare PyTensor RV, plus its RNG updates
        mu_t = mu_tm1 + pm.Normal.dist(sigma=sigma_level)
        return mu_t, collect_default_updates(outputs=[mu_t])

    mu, _ = pytensor.scan(fn=grw_step,
                          outputs_info=[{"initial": mu_init}],
                          non_sequences=[sigma_level],
                          n_steps=n_steps,
                          name='statespace',
                          strict=True)
    return mu

# PyMC Model
coords = {'time': np.arange(n_steps)}
with pm.Model(coords=coords) as model:
    y_data = pm.Data('y_data', data, dims=['time'])
    mu_init = pm.Normal('mu_init', mu=0, sigma=1)
    sigma_level = pm.HalfNormal('sigma_level', sigma=1)
    sigma_irreg = pm.HalfNormal('sigma_irreg', sigma=1)
    mu = pm.CustomDist('hidden_states',
                       mu_init,
                       sigma_level,
                       dist=statespace_dist,
                       dims=['time'])
    y_hat = pm.Normal('y_hat', mu, sigma_irreg, observed=y_data, dims=['time'])
    idata = pm.sample()
    # Posterior predictive draws of y_hat, used for the predictive band in the plot
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
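For reference, reading the code directly, the model is a Gaussian random walk for the hidden level observed with Gaussian noise (sometimes called a local-level model). Writing $T$ for `n_steps`, and noting that PyMC’s `sigma` is a standard deviation:

$$
\begin{aligned}
\mu_0 &\sim \mathcal{N}(0, 1), \quad \sigma_{\text{level}} \sim \text{HalfNormal}(1), \quad \sigma_{\text{irreg}} \sim \text{HalfNormal}(1) \\
\mu_t &= \mu_{t-1} + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_{\text{level}}^2), \quad t = 1, \dots, T \\
y_t &\sim \mathcal{N}(\mu_t, \sigma_{\text{irreg}}^2), \quad t = 1, \dots, T
\end{aligned}
$$

The scan output contains only $\mu_1, \dots, \mu_T$ (the initial value is not part of it), which is why `hidden_states` has length `n_steps` and lines up with the `time` dimension.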
Plot results:
```python
fig, ax = plt.subplots(figsize=(14, 4))
x_grid = coords['time']

# Posterior mean and HDI of the hidden states
mu_data = idata.posterior.hidden_states
mu_hdi = az.hdi(mu_data).hidden_states
ax.plot(x_grid, mu_data.mean(dim=['chain', 'draw']), lw=2)
ax.fill_between(x_grid, *mu_hdi.values.T, alpha=0.5)

# Posterior predictive HDI of the observations
post_data = idata.posterior_predictive.y_hat
hdi = az.hdi(post_data).y_hat
# ax.plot(x_grid, post_data.mean(dim=['chain', 'draw']))
ax.fill_between(x_grid, *hdi.values.T, alpha=0.25, color='tab:blue')

# Observed data
ax.plot(data, color='k', ls='--', lw=0.5)
```
It’s good to work through all this to get a deep understanding of how PyMC works under the hood. But if you just want ready-to-go statespace models, you can check out the statespace module of pymc-experimental.
For more examples of writing scan-based time series models though, you can check here and here.