Possible bug: sample_prior_predictive with JAX fails for time-varying statespace models

In the code below, sample_prior_predictive(compile_kwargs={"mode": "JAX"}) fails for the time-varying model. (It's a simple one-state model with exponential decay and a source term that may or may not vary with time.)

Also, sample_unconditional_prior consistently complains about …

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Subtensor{i, :, j}(state_cov{}, 0, -1)

…but the code seems to run fine despite that. Is this something to be concerned about?

import jax
jax.config.update("jax_platform_name", "cpu")

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import arviz as az
import pytensor
import pytensor.tensor as pt 
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc_extras.statespace.models.utilities import make_default_coords
from pymc_extras.statespace.core import PyMCStateSpace

# # this fails (Array slice indices must have static start/stop/step to be used with NumPy indexing syntax, among other things)
# time_varying = True
# sample_with_JAX = True

# this runs, but sample_unconditional_prior throws "Rewrite failure due to: constant_folding" errors
time_varying = True
sample_with_JAX = False

# # this runs, but throws "Rewrite failure due to: constant_folding" errors
# time_varying = False
# sample_with_JAX = True

# # this runs, but throws "Rewrite failure due to: constant_folding" errors
# time_varying = False
# sample_with_JAX = False

class SimpleModel(PyMCStateSpace):
   def __init__(self, n_timesteps=100):
      k_states = 1
      k_posdef = 0
      k_endog = 1
      self.n_timesteps = n_timesteps
      super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)

   def make_symbolic_graph(self):
      # sets the matrix & vector values
      x0 = self.make_and_register_variable('x0', shape=())
      P0 = self.make_and_register_variable('P0', shape=())
      self.ssm["initial_state",:] = x0
      self.ssm['initial_state_cov',:] = P0
      
      self.ssm["design", 0, 0] = 1
      self.ssm["transition", 0, 0] = 0.95
      
      if time_varying:
         source = self.make_and_register_variable("source", shape=(self.n_timesteps,))
         S_I = pt.zeros((self.n_timesteps, self.k_states))
         S_I = S_I[:, 0].set(source)
         self.ssm["state_intercept"] = S_I
      else:
         source = self.make_and_register_variable("source", shape=())
         self.ssm["state_intercept", 0] = source
        
   @property
   def param_names(self):
      params = ["x0", "P0", "source"]
      return params
    
   @property
   def state_names(self):
      return ['x1_latent']
    
   @property
   def shock_names(self):
      return []
    
   @property
   def observed_states(self):
      return ['x1']
    
   @property
   def coords(self):
      return make_default_coords(self)



N = 100
if time_varying:
   data_vec = np.linspace(0.5, 2, N) + np.random.normal(0, 0.1, size=N)
else:
   data_vec = np.random.normal(0.5, 0.1, size=N)
data = pd.DataFrame(data_vec, index=pd.RangeIndex(N), columns=['x1'])

sm = SimpleModel(n_timesteps=N)

with pm.Model() as pymc_mod:
   x0 = pm.Normal("x0", mu=0, sigma=1)
   P0 = pm.HalfNormal("P0", sigma=0.01)

   if time_varying:
      mu = pm.Normal('mu', mu=0, sigma=1, shape=(2,))
      source = pm.Deterministic('source', pt.linspace(mu[0], mu[1], N))
   else:
      mu = pm.Normal('mu', mu=0, sigma=1)
      source = pm.Deterministic('source', mu)

   sm.build_statespace_graph(data=data, mode="JAX")

with pymc_mod:
   if sample_with_JAX:
      prior = pm.sample_prior_predictive(compile_kwargs={"mode": "JAX"})
   else:
      prior = pm.sample_prior_predictive()


uncond_prior = sm.sample_unconditional_prior(prior, steps=100)
prior_obs = uncond_prior.prior_observed.stack(sample=["chain", "draw"])

fig, ax = plt.subplots(figsize=(14, 4), dpi=144)
(
   prior_obs.sel(prior_observed_dim_1=0).plot.line(
      x="prior_observed_dim_0", add_legend=False, ax=ax, color="0.5", alpha=0.05
   )
)
plt.scatter(np.linspace(0,N,N), data_vec)
plt.show()

with pymc_mod:
   idata = pm.sample(
      nuts_sampler="nutpie",
      nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
      cores=8,
   )

az.plot_trace(idata, var_names=["x0", 'mu'])
plt.show()

uncond_post = sm.sample_unconditional_posterior(idata, steps=100)
post_obs = uncond_post.posterior_observed.stack(sample=["chain", "draw"])

fig, ax = plt.subplots(figsize=(14, 4), dpi=144)
(
   post_obs.sel(posterior_observed_dim_1=0).plot.line(
      x="posterior_observed_dim_0", add_legend=False, ax=ax, color="0.5", alpha=0.05
   )
)
plt.scatter(np.linspace(0,N,N), data_vec)
plt.show()

Hey @JSR

The constant folding error is annoying, but it won't ruin your program. It means PyTensor hit an error while trying to apply a graph optimization, so it left that part of the graph as it found it.

But the underlying reason you're seeing this is that your model has no source of uncertainty at all: no innovations (k_posdef=0) and no measurement error. This is called a stochastic singularity (in the macroeconomics literature, at least). The Kalman gain is not identified, and you get numerical blow-ups when the log-likelihood computation tries to invert the innovation covariance. I'm surprised the model fit at all; ideally, you would get some kind of informative error in this case.
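A toy scalar version of the filter (plain NumPy, illustrative numbers only, not the pymc-extras implementation) shows the problem. With T = 0.95, Z = 1, and Q = H = 0, as in your model, the first update drives the state covariance to exactly zero, so the innovation covariance F = Z P Z' + H that the filter must invert hits zero on the very next step:

```python
import numpy as np

def innovation_covs(Q, H, P0=0.01**2, T=0.95, Z=1.0, steps=5):
    """Scalar Kalman filter: return the innovation covariance F at each step."""
    P, Fs = P0, []
    for _ in range(steps):
        P_pred = T * P * T + Q       # predicted state covariance
        F = Z * P_pred * Z + H       # innovation covariance; the filter needs 1/F
        Fs.append(F)
        if F <= 0:                   # stochastic singularity: the gain is 0/0
            break
        K = P_pred * Z / F           # Kalman gain
        P = (1.0 - K * Z) * P_pred   # filtered state covariance
    return np.array(Fs)

F_sing = innovation_covs(Q=0.0, H=0.0)  # no noise anywhere: F collapses to 0
F_ok = innovation_covs(Q=0.1, H=0.0)    # one state innovation keeps F positive
print(F_sing)  # second entry is exactly 0.0
print(F_ok)    # all entries stay positive
```

With any non-zero Q (or H), F stays bounded away from zero and the gain is well defined at every step.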

If you add either a state_cov or an obs_cov with some non-zero elements, it should clear up. I tried it: it does clear up with state_cov, but not with obs_cov. I would consider that a bug, so if you could open an issue on the pymc-extras repo, we can get it fixed.
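For concreteness, here is roughly what the state_cov route looks like for your SimpleModel. This is a sketch, not tested code: sigma_state is a name I made up, and it would also need a prior in the PyMC model, an entry in param_names, and a corresponding entry in shock_names.

```python
# in __init__: allow one innovation instead of zero
super().__init__(k_endog=1, k_states=1, k_posdef=1)

# in make_symbolic_graph, alongside the existing assignments:
sigma_state = self.make_and_register_variable("sigma_state", shape=())
self.ssm["selection", 0, 0] = 1.0             # the shock feeds into x1_latent
self.ssm["state_cov", 0, 0] = sigma_state**2  # non-zero innovation variance
```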

PS: With pytensor >= 2.37 and pymc-extras >= 0.8.0, I now recommend mode="NUMBA" over mode="JAX". We did a ton of work to enable numba caching, so compile times are way, way down, and in my experience it samples faster than JAX on CPU.

PPS: We also refactored custom statespace models to make the bookkeeping of those properties less onerous; you can see the updated examples here. The changes are backwards compatible, so you can keep the properties that return lists of strings with no issue. In your case, though, you can delete the coords property and keep everything else the same: the base class adds the default coords automatically.