Scan value error in State Space

I am getting a weird error about an incompatible argument being sent to the scan function. I checked all the shapes of the core matrices using ss_mod.ssm["matrix"].eval().shape and I think I’ve got the correct shapes everywhere. Any insights into what I am doing wrong would really be appreciated. Below is what I am doing:

import jax
import numpyro
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
import pytensor.tensor as pt
import pymc as pm
from pymc_experimental.statespace.utils.constants import (
    ALL_STATE_DIM,
    ALL_STATE_AUX_DIM,
    OBS_STATE_DIM,
    SHOCK_DIM,
    TIME_DIM
)
from pymc_experimental.statespace.models.utilities import make_default_coords

data = np.random.normal(size=1001)[:, np.newaxis]
exog = np.random.normal(size=1001)[:, np.newaxis]

class AutoRegressiveThree(PyMCStateSpace):
    def __init__(self, exog: np.ndarray):
        k_states = 4  # size of the state vector x
        k_posdef = 1  # number of shocks (size of the state covariance matrix Q)
        k_endog = 1  # number of observed states
        self.k_exog = exog.shape[1]
        self.n_obs = exog.shape[0]

        super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)

    def make_symbolic_graph(self):
        exogenous_data = self.make_and_register_data("exogenous_data", shape=(None, 1))
        beta_exog = self.make_and_register_variable("beta_exog", shape=(1,))

        x0 = self.make_and_register_variable("x0", shape=(3,))
        P0 = self.make_and_register_variable("P0", shape=(4, 4))
        ar_params = self.make_and_register_variable("ar_params", shape=(3,))
        sigma_x = self.make_and_register_variable("sigma_x", shape=(1,))

        a = np.eye(4, k=-1)
        a[3, 2] = 0.
        a[3, 3] = 1.
        self.ssm["transition", :, :] = a
        self.ssm["transition", 0, :3] = ar_params

        self.ssm["selection", 0, 0] = 1.
        
        # design needs to have shape (1001, 1, 4) = (TIME_INDEX, K_POSDEF, K_STATES)
        self.ssm["design"] = pt.concatenate(
                                                (
                                                    pt.expand_dims(pt.repeat(pt.as_tensor([[1.,0.,0.]]), 1001, axis=0), 1),
                                                    pt.expand_dims(exogenous_data, 1)
                                                ), axis=2
                                            )

        self.ssm["initial_state", :] = pt.concatenate((x0, beta_exog), axis=0)
        self.ssm["initial_state_cov", :, :] = P0
        self.ssm["state_cov", :, :] = sigma_x

    @property
    def data_names(self) -> list[str]:
        """
        Names of data variables expected by the model.

        This does not include the observed data series, which is automatically handled by PyMC. This property only
        needs to be implemented for models that expect exogenous data.
        """
        return ["exogenous_data"]
    
    @property
    def data_info(self):
        """
        Information about Data variables that need to be declared in the PyMC model block.

        Returns a dictionary of data_name: dictionary of property-name:property description pairs. The return value is
        used by the ``_print_data_requirements`` method, to print a message telling users how to define the necessary
        data for the model. Each dictionary should have the following key-value pairs:
            * key: "shape", value: a tuple of integers
            * key: "dims", value: tuple of strings
        """
        return {
            "exogenous_data": {
                "shape": (None, 1),
                "dims": (TIME_DIM, "exog_states")
            }
        }

    @property
    def param_names(self):
        return ["x0", "P0", "ar_params", "sigma_x", "beta_exog"]

    @property
    def state_names(self):
        # Since the three states are lags of the data, i'll call them L1, L2 L3
        return ["L1.data", "L2.data", "L3.data", "beta]

    @property
    def shock_names(self):
        # There is one shock, called the "innovations" in the literature, so i'll go with that
        return ["innovations"]

    @property
    def observed_states(self):
        # Inspired, I know
        return ["data"]

    @property
    def param_dims(self):
        # This needs to map the 4 parameters to the names of the coords.
        # There are special standardized names to use here. You can import them from
        # pymc_experimental.statespace.utils.constants

        # Not the best system. Something to improve on in the future.
        return {
            "x0": (ALL_STATE_DIM,),
            "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
            "ar_params": ("ar_lags",),
            "sigma_x": (SHOCK_DIM,),
            "beta_exog": ("exog_state",)
        }

    @property
    def coords(self):
        # This function puts coords on all those statespace matrices (x0, P0, c, d, T, Z, R, H, Q)
        # and also on the different filter outputs so you don't have to worry about it. You only need to set
        # the coords for the dims unique to your model. In this case, it's just "ar_lags"
        coords = make_default_coords(self)
        coords.update({"ar_lags": [1, 2, 3], "beta_exog_dims": np.arange(self.k_exog), "n_obs": np.arange(self.n_obs)})

        return coords

    @property
    def param_info(self):
        # This needs to return a dictionary where the keys are the parameter names, and the values are a
        # dictionary. The value dictionary should have the following keys: "shape", "constraints", and "dims".

        info = {
            "x0": {
                "shape": (self.k_states - 1,),
                "constraints": "None",
            },
            "P0": {
                "shape": (self.k_states, self.k_states),
                "constraints": "Positive Semi-definite",
            },
            "sigma_x": {
                "shape": (self.k_posdef,),
                "constraints": "Positive",
            },
            "ar_params": {
                "shape": (3,),
                "constraints": "None",
            },
            "beta_exog": {
                "shape": (self.k_exog,),
                "constraints": "None",
            },
        }

        # Lazy way to add the dims without making any typos
        for name in self.param_names:
            info[name]["dims"] = self.param_dims[name]

        return info

with pm.Model(coords=ar3.coords) as pymc_mod:
    exogenous_data = pm.Data("exogenous_data", exog, dims=['time', 'exog_state'])
    beta_exog = pm.Normal("beta_exog", 0, 1, dims=["exog_state"])

    x0 = pm.Deterministic(
        "x0",
        pt.as_tensor([0., 0., 0.]),
    )
    P0 = pm.Deterministic("P0", pt.eye(4) * 10., dims=["state", "state_aux"])

    ar_params = pm.Normal("ar_params", sigma=0.25, dims=["ar_lags"])
    sigma_x = pm.Exponential("sigma_x", 1, dims=["shock"])
    

    ar3.build_statespace_graph(data=data, mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro")

And the full error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[573], line 17
     13 sigma_x = pm.Exponential("sigma_x", 1, dims=["shock"])
     16 ar3.build_statespace_graph(data=data, mode="JAX")
---> 17 idata = pm.sample(nuts_sampler="numpyro")

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/mcmc.py:741, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    736         raise ValueError(
    737             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    738         )
    740     with joined_blas_limiter():
--> 741         return _sample_external_nuts(
    742             sampler=nuts_sampler,
    743             draws=draws,
    744             tune=tune,
    745             chains=chains,
    746             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    747             random_seed=random_seed,
    748             initvals=initvals,
    749             model=model,
    750             var_names=var_names,
    751             progressbar=progressbar,
    752             idata_kwargs=idata_kwargs,
    753             compute_convergence_checks=compute_convergence_checks,
    754             nuts_sampler_kwargs=nuts_sampler_kwargs,
    755             **kwargs,
    756         )
    758 if isinstance(step, list):
    759     step = CompoundStep(step)

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/mcmc.py:364, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    361 elif sampler in ("numpyro", "blackjax"):
    362     import pymc.sampling.jax as pymc_jax
--> 364     idata = pymc_jax.sample_jax_nuts(
    365         draws=draws,
    366         tune=tune,
    367         chains=chains,
    368         target_accept=target_accept,
    369         random_seed=random_seed,
    370         initvals=initvals,
    371         model=model,
    372         var_names=var_names,
    373         progressbar=progressbar,
    374         nuts_sampler=sampler,
    375         idata_kwargs=idata_kwargs,
    376         compute_convergence_checks=compute_convergence_checks,
    377         **nuts_sampler_kwargs,
    378     )
    379     return idata
    381 else:

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:611, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    608     raise ValueError(f"{nuts_sampler=} not recognized")
    610 tic1 = datetime.now()
--> 611 raw_mcmc_samples, sample_stats, library = sampler_fn(
    612     model=model,
    613     target_accept=target_accept,
    614     tune=tune,
    615     draws=draws,
    616     chains=chains,
    617     chain_method=chain_method,
    618     progressbar=progressbar,
    619     random_seed=random_seed,
    620     initial_points=initial_points,
    621     nuts_kwargs=nuts_kwargs,
    622 )
    623 tic2 = datetime.now()
    625 if idata_kwargs is None:

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:425, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    421 import numpyro
    423 from numpyro.infer import MCMC, NUTS
--> 425 logp_fn = get_jaxified_logp(model, negative_logp=False)
    427 nuts_kwargs.setdefault("adapt_step_size", True)
    428 nuts_kwargs.setdefault("adapt_mass_matrix", True)

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:151, in get_jaxified_logp(model, negative_logp)
    149 if not negative_logp:
    150     model_logp = -model_logp
--> 151 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    153 def logp_fn_wrap(x):
    154     return logp_fn(*x)[0]

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:126, in get_jaxified_graph(inputs, outputs)
    121 def get_jaxified_graph(
    122     inputs: list[TensorVariable] | None = None,
    123     outputs: list[TensorVariable] | None = None,
    124 ) -> list[TensorVariable]:
    125     """Compile a PyTensor graph into an optimized JAX function."""
--> 126     graph = _replace_shared_variables(outputs) if outputs is not None else None
    128     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    129     # We need to add a Supervisor to the fgraph to be able to run the
    130     # JAX sequential optimizer without warnings. We made sure there
    131     # are no mutable input variables, so we only need to check for
    132     # "destroyers". This should be automatically handled by PyTensor
    133     # once https://github.com/aesara-devs/aesara/issues/637 is fixed.

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:117, in _replace_shared_variables(graph)
    110     raise ValueError(
    111         "Graph contains shared variables with default_update which cannot "
    112         "be safely replaced."
    113     )
    115 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 117 new_graph = clone_replace(graph, replace=replacements)
    118 return new_graph

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/replace.py:85, in clone_replace(output, replace, **rebuild_kwds)
     82 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
     84 # TODO Explain why we call it twice ?!
---> 85 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
     87 return outs

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:313, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
    311 for v in outputs:
    312     if isinstance(v, Variable):
--> 313         cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
    314         cloned_outputs.append(cloned_v)
    315     elif isinstance(v, Out):

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    187 if owner not in clone_d:
    188     for i in owner.inputs:
--> 189         clone_v_get_shared_updates(i, copy_inputs_over)
    190     clone_node_and_cache(
    191         owner,
    192         clone_d,
    193         strict=rebuild_strict,
    194         clone_inner_graphs=clone_inner_graphs,
    195     )
    196 return clone_d.setdefault(v, v)

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    187 if owner not in clone_d:
    188     for i in owner.inputs:
--> 189         clone_v_get_shared_updates(i, copy_inputs_over)
    190     clone_node_and_cache(
    191         owner,
    192         clone_d,
    193         strict=rebuild_strict,
    194         clone_inner_graphs=clone_inner_graphs,
    195     )
    196 return clone_d.setdefault(v, v)

    [... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 189 (2 times)]

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    187 if owner not in clone_d:
    188     for i in owner.inputs:
--> 189         clone_v_get_shared_updates(i, copy_inputs_over)
    190     clone_node_and_cache(
    191         owner,
    192         clone_d,
    193         strict=rebuild_strict,
    194         clone_inner_graphs=clone_inner_graphs,
    195     )
    196 return clone_d.setdefault(v, v)

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:190, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    188         for i in owner.inputs:
    189             clone_v_get_shared_updates(i, copy_inputs_over)
--> 190         clone_node_and_cache(
    191             owner,
    192             clone_d,
    193             strict=rebuild_strict,
    194             clone_inner_graphs=clone_inner_graphs,
    195         )
    196     return clone_d.setdefault(v, v)
    197 elif isinstance(v, SharedVariable):

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/basic.py:1296, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
   1292 new_op: Op | None = cast(Optional["Op"], clone_d.get(node.op))
   1294 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1296 new_node = node.clone_with_new_inputs(
   1297     cloned_inputs,
   1298     # Only clone inner-graph `Op`s when there isn't a cached clone (and
   1299     # when `clone_inner_graphs` is enabled)
   1300     clone_inner_graph=clone_inner_graphs if new_op is None else False,
   1301     **kwargs,
   1302 )
   1304 if new_op:
   1305     # If we didn't clone the inner-graph `Op` above, because
   1306     # there was a cached version, set the cloned `Apply` to use
   1307     # the cached clone `Op`
   1308     new_node.op = new_op

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/basic.py:297, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
    294     if isinstance(new_op, HasInnerGraph) and clone_inner_graph:  # type: ignore
    295         new_op = new_op.clone()  # type: ignore
--> 297     new_node = new_op.make_node(*new_inputs)
    298     new_node.tag = copy(self.tag).__update__(new_node.tag)
    299 else:

File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/scan/op.py:1196, in Scan.make_node(self, *inputs)
   1194     new_inputs.append(outer_nonseq)
   1195     if not outer_nonseq.type.in_same_class(inner_nonseq.type):
-> 1196         raise ValueError(
   1197             f"Argument {outer_nonseq} given to the scan node is not"
   1198             f" compatible with its corresponding loop function variable {inner_nonseq}"
   1199         )
   1201 for outer_nitsot in self.outer_nitsot(inputs):
   1202     # For every nit_sot input we get as input a int/uint that
   1203     # depicts the size in memory for that sequence. This feature is
   1204     # used by truncated BPTT and by scan space optimization
   1205     if (
   1206         str(outer_nitsot.type.dtype) not in integer_dtypes
   1207         or outer_nitsot.ndim != 0
   1208     ):

ValueError: Argument [[ 0.69903 ... 32837271]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>

The first step to debugging an error like this is to comment out the pm.sample line, then look at pymc_mod.logp().dprint(). You’ll get the computation graph, and you need to look through it for *4-<Matrix(float64, shape=(?, ?))>. It will be in the scan subgraph, all the way towards the bottom. The point is to figure out which input that is, then check that the shapes of the incoming data/priors are correct.

Thank you @jessegrabowski. I was able to trace back that matrix to the exogenous data input, however, the shape of it is correct. I am wondering if maybe I should not be defining the time dimension on the design matrix manually? Is this something that the module handles in the background?

Hey @jessegrabowski, something odd is happening. When I build the state-space graph using the default mode and use the pymc sampler instead of numpyro the model samples just fine (very slowly though).

You have to embrace the symbolic paradigm of pytensor more aggressively. In a normal framework I agree you would need to provide the exogenous data to the class constructor, but here we’re not going to. Later, we’re going to make a symbolic data container and we can work with that. So the __init__ method of your class (which isn’t really an AR3, I might mention) should look like this:

    def __init__(self, k_exog=1):
        k_states = 3 + k_exog # size of the state vector x
        k_posdef = 1  # number of shocks (size of the state covariance matrix Q)
        k_endog = 1  # number of observed states
        self.k_exog = k_exog
        super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)

We really don’t even need to know k_exog, but we’ll pass it in for making dims. Scratch that, you actually do need to know it, otherwise the number of state variables is unknown and that’s not allowed :slight_smile:

For the rest, we build our computatonal graph using the symbolic shapes of the data variable created in make_symbolic_graph:

        exogenous_data = self.make_and_register_data("exogenous_data", shape=(None, None))
        beta_exog = self.make_and_register_variable("beta_exog", shape=(1,))
        
        n_obs = exogenous_data.shape[0]

Then, when you make the design matrix, repeat n_obs times, rather than hard-coded 1001:

        self.ssm["design"] = pt.concatenate(
                                                (
                                                    pt.expand_dims(pt.repeat(pt.as_tensor([[1.,0.,0.]]), n_obs, axis=0), 1),
                                                    pt.expand_dims(exogenous_data, 1)
                                                ), axis=2
                                            )

Finally, in the coords property, you have to remove the n_obs coordinate, because we don’t know it. It’s a duplicate of the time coordinate anyway, so nothing is lost.

After these changes, the model sampled fine for me.

Thank you, @jessegrabowski. Quick question why is the data’s shape (None, None)?:
exogenous_data = self.make_and_register_data("exogenous_data", shape=(None, None))
Shouldn’t it be shape=(None, self.k_exog)?

Assuming it is shape=(None, self.k_exog) I think something is going on with numpyro. I still get the scan value error after making the changes you suggested. It works using the built-in pymc sampler, though.

Yes, you are right about the shapes.

Can you make sure everything is up to date? In particular, switch from pymc_experimental to pymc_extras, but also make sure pymc and pytensor are the most recent.

1 Like

Thank you, Jesse. I created a clean environment with pymc-extras instead of pymc-experimental and now it samples with mode="JAX" and nuts_sampler="numpyro". Again, thank you for all your help and patience.