In the code below, `sample_prior_predictive(compile_kwargs={"mode": "JAX"})` fails for the time-varying model. (It is a simple one-state model with exponential decay and a source term that may or may not change over time.)
Also, `sample_unconditional_prior` consistently logs the following errors:
```
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Subtensor{i, :, j}(state_cov{}, 0, -1)
```
…but the code still seems to run fine despite them. Is this something to be concerned about?
import jax
jax.config.update("jax_platform_name", "cpu")
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import arviz as az
import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc_extras.statespace.models.utilities import make_default_coords
from pymc_extras.statespace.core import PyMCStateSpace
# Configuration switches for this reproduction. Outcomes as originally reported:
#
#   time_varying  sample_with_JAX  outcome
#   ------------  ---------------  ----------------------------------------------
#   True          True             sample_prior_predictive fails ("Array slice
#                                  indices must have static start/stop/step to
#                                  be used with NumPy indexing syntax", ...)
#   True          False            runs, but sample_unconditional_prior logs
#                                  "Rewrite failure due to: constant_folding"
#   False         True             runs, but logs the same constant_folding errors
#   False         False            runs, but logs the same constant_folding errors
time_varying, sample_with_JAX = True, False
class SimpleModel(PyMCStateSpace):
    """One-state statespace model: exponential decay plus a source term.

    The latent state is carried forward with a fixed transition coefficient of
    0.95 and shifted by a state intercept (the ``source`` parameter). It is
    observed directly (design matrix = 1). There are no state shocks
    (``k_posdef=0``).

    Parameters
    ----------
    n_timesteps : int
        Length of the time-varying source vector. Presumably this must equal
        the number of observations passed to ``build_statespace_graph``
        (TODO confirm against pymc-extras).
    time_varying : bool or None
        If True, ``source`` is a vector of shape ``(n_timesteps,)``. If False,
        ``source`` is a scalar. If None, the module-level ``time_varying`` flag
        is used (the original script's behaviour), defaulting to False when
        that flag is not defined.
    """

    def __init__(self, n_timesteps=100, time_varying=None):
        if time_varying is None:
            # Backward compatibility: the script originally configured the model
            # through a module-level flag rather than a constructor argument.
            # The parameter shadows that name, hence the explicit globals() lookup.
            time_varying = globals().get("time_varying", False)
        # Set before super().__init__ so the attributes exist in case the base
        # constructor calls make_symbolic_graph, which reads them.
        self.n_timesteps = n_timesteps
        self.time_varying_source = bool(time_varying)
        # NOTE(review): with k_posdef=0 the state covariance is an empty matrix.
        # The "Rewrite failure due to: constant_folding" log on
        # Subtensor{i, :, j}(state_cov{}, 0, -1) looks like pytensor trying to
        # fold an index into that empty constant at compile time. Failed rewrites
        # are skipped rather than fatal, which would explain why sampling still
        # works. Confirm with pymc-extras.
        super().__init__(k_endog=1, k_states=1, k_posdef=0)

    def make_symbolic_graph(self):
        """Register the model parameters and fill the statespace matrices."""
        x0 = self.make_and_register_variable("x0", shape=())
        P0 = self.make_and_register_variable("P0", shape=())
        self.ssm["initial_state", :] = x0
        self.ssm["initial_state_cov", :] = P0
        self.ssm["design", 0, 0] = 1
        self.ssm["transition", 0, 0] = 0.95

        if self.time_varying_source:
            source = self.make_and_register_variable(
                "source", shape=(self.n_timesteps,)
            )
            # The time-varying intercept carries a leading time axis:
            # (n_timesteps, k_states). Only state 0 receives the source.
            intercept = pt.zeros((self.n_timesteps, self.k_states))
            self.ssm["state_intercept"] = intercept[:, 0].set(source)
        else:
            source = self.make_and_register_variable("source", shape=())
            self.ssm["state_intercept", 0] = source

    @property
    def param_names(self):
        """Names of the parameters registered in ``make_symbolic_graph``."""
        return ["x0", "P0", "source"]

    @property
    def state_names(self):
        """Name of the single latent state."""
        return ["x1_latent"]

    @property
    def shock_names(self):
        """No shocks: the model is built with ``k_posdef=0``."""
        return []

    @property
    def observed_states(self):
        """Name of the single observed series (the data column)."""
        return ["x1"]

    @property
    def coords(self):
        """Default coordinates generated by pymc-extras for this model."""
        return make_default_coords(self)
N = 100

# Seeded Generator (instead of the legacy global np.random state) so the
# synthetic data, and therefore every downstream fit and plot, is reproducible.
rng = np.random.default_rng(42)
if time_varying:
    # Linearly rising level plus noise.
    data_vec = np.linspace(0.5, 2, N) + rng.normal(0, 0.1, size=N)
else:
    # Constant level plus noise.
    data_vec = rng.normal(0.5, 0.1, size=N)
data = pd.DataFrame(data_vec, index=pd.RangeIndex(N), columns=["x1"])

sm = SimpleModel(n_timesteps=N)

with pm.Model() as pymc_mod:
    # Priors for the parameters SimpleModel registers (x0, P0, source).
    x0 = pm.Normal("x0", mu=0, sigma=1)
    P0 = pm.HalfNormal("P0", sigma=0.01)
    if time_varying:
        # Source interpolated linearly between two endpoint values.
        mu = pm.Normal("mu", mu=0, sigma=1, shape=(2,))
        source = pm.Deterministic("source", pt.linspace(mu[0], mu[1], N))
    else:
        mu = pm.Normal("mu", mu=0, sigma=1)
        source = pm.Deterministic("source", mu)
    sm.build_statespace_graph(data=data, mode="JAX")
# JAX-compiled functions need static shapes. The statespace graph registers the
# observed data as mutable pm.Data, so lengths derived from it are only known
# at runtime. Presumably this includes the scan length used to slice the
# time-varying state intercept, which would produce "Array slice indices must
# have static start/stop/step ..." only in the time-varying case.
# freeze_dims_and_data returns a copy of the model with data and dims baked in
# as constants, so those shapes become static.
if sample_with_JAX:
    prior_model = freeze_dims_and_data(pymc_mod)
    prior_kwargs = {"compile_kwargs": {"mode": "JAX"}}
else:
    prior_model = pymc_mod
    prior_kwargs = {}

with prior_model:
    prior = pm.sample_prior_predictive(**prior_kwargs)

# Unconditional simulation only needs the prior draws, not the model context.
# Simulate exactly N steps so the length matches the data and the
# (n_timesteps, k_states) time-varying intercept.
uncond_prior = sm.sample_unconditional_prior(prior, steps=N)
prior_obs = uncond_prior.prior_observed.stack(sample=["chain", "draw"])

fig, ax = plt.subplots(figsize=(14, 4), dpi=144)
prior_obs.sel(prior_observed_dim_1=0).plot.line(
    x="prior_observed_dim_0", add_legend=False, ax=ax, color="0.5", alpha=0.05
)
# Simulated paths are indexed by integer step 0..N-1. Plot the data on the same
# grid; np.linspace(0, N, N) would stretch it over 0..N.
ax.scatter(np.arange(N), data_vec)
plt.show()
# NOTE(review): nutpie's JAX backend presumably freezes data/dims itself when
# compiling, which would explain why this works even for the time-varying model.
# Confirm against the nutpie docs.
with pymc_mod:
    idata = pm.sample(
        nuts_sampler="nutpie",
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
        cores=8,
    )

az.plot_trace(idata, var_names=["x0", "mu"])
plt.show()

# Simulate exactly N steps so the length matches the data and the
# (n_timesteps, k_states) time-varying intercept.
uncond_post = sm.sample_unconditional_posterior(idata, steps=N)
post_obs = uncond_post.posterior_observed.stack(sample=["chain", "draw"])

fig, ax = plt.subplots(figsize=(14, 4), dpi=144)
post_obs.sel(posterior_observed_dim_1=0).plot.line(
    x="posterior_observed_dim_0", add_legend=False, ax=ax, color="0.5", alpha=0.05
)
# Simulated paths are indexed by integer step 0..N-1. Plot the data on the same
# grid; np.linspace(0, N, N) would stretch it over 0..N.
ax.scatter(np.arange(N), data_vec)
plt.show()