I am getting a weird error about an incompatible argument being passed to the scan function. I checked the shapes of all the core matrices using `ss_mod.ssm["matrix"].eval().shape`,
and I think I’ve got the correct shapes everywhere. Any insight into what I am doing wrong would be really appreciated. Below is what I am doing:
import jax
import numpyro
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
import pytensor.tensor as pt
import pymc as pm
from pymc_experimental.statespace.utils.constants import (
ALL_STATE_DIM,
ALL_STATE_AUX_DIM,
OBS_STATE_DIM,
SHOCK_DIM,
TIME_DIM
)
from pymc_experimental.statespace.models.utilities import make_default_coords
# Synthetic inputs: 1001 standard-normal draws each, stored as column vectors of shape (1001, 1).
data = np.random.normal(size=1001).reshape(-1, 1)
exog = np.random.normal(size=1001).reshape(-1, 1)
class AutoRegressiveThree(PyMCStateSpace):
    """AR(3) state space model with a single exogenous regressor.

    The state vector has four entries: three lag states of the series plus one
    regression coefficient that the transition matrix carries forward
    unchanged. The design matrix is time-varying: at every time step it is
    ``[1, 0, 0, exog_t]``, so the observation is the first lag state plus
    ``exog_t * beta``.

    Parameters
    ----------
    exog : np.ndarray
        Exogenous data of shape ``(n_obs, 1)``. Only its shape is read here;
        the values reach the model through the ``exogenous_data`` data
        variable declared in the PyMC model block.
    """

    def __init__(self, exog: np.ndarray) -> None:
        k_states = 4  # size of the state vector x (3 lags + 1 regression coefficient)
        k_posdef = 1  # number of shocks (size of the state covariance matrix Q)
        k_endog = 1  # number of observed states

        # Set before super().__init__ so they are already available if the base
        # class builds the symbolic graph during construction.
        self.k_exog = exog.shape[1]
        self.n_obs = exog.shape[0]

        super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef)

    def make_symbolic_graph(self) -> None:
        """Register the model's placeholders and fill in the state space matrices."""
        exogenous_data = self.make_and_register_data("exogenous_data", shape=(None, 1))
        beta_exog = self.make_and_register_variable("beta_exog", shape=(1,))
        x0 = self.make_and_register_variable("x0", shape=(3,))
        P0 = self.make_and_register_variable("P0", shape=(4, 4))
        ar_params = self.make_and_register_variable("ar_params", shape=(3,))
        sigma_x = self.make_and_register_variable("sigma_x", shape=(1,))

        # Sub-diagonal ones shift each lag down one slot. The last row is
        # replaced by an identity entry so the regression coefficient is
        # carried forward unchanged instead of receiving the third lag.
        transition = np.eye(4, k=-1)
        transition[3, 2] = 0.
        transition[3, 3] = 1.
        self.ssm["transition", :, :] = transition
        self.ssm["transition", 0, :3] = ar_params

        # The single shock enters through the first state only.
        self.ssm["selection", 0, 0] = 1.

        # design needs to have shape (n_obs, 1, 4) = (TIME_INDEX, K_ENDOG, K_STATES).
        # The constant part [1, 0, 0] is repeated self.n_obs times (the length
        # of the exog array given to __init__) rather than a hard-coded 1001.
        lag_loadings = pt.repeat(pt.as_tensor([[1., 0., 0.]]), self.n_obs, axis=0)
        self.ssm["design"] = pt.concatenate(
            (
                pt.expand_dims(lag_loadings, 1),
                pt.expand_dims(exogenous_data, 1),
            ),
            axis=2,
        )

        self.ssm["initial_state", :] = pt.concatenate((x0, beta_exog), axis=0)
        self.ssm["initial_state_cov", :, :] = P0
        self.ssm["state_cov", :, :] = sigma_x

    @property
    def data_names(self) -> list[str]:
        """
        Names of data variables expected by the model.

        This does not include the observed data series, which is automatically handled by PyMC. This property only
        needs to be implemented for models that expect exogenous data.
        """
        return ["exogenous_data"]

    @property
    def data_info(self):
        """
        Information about Data variables that need to be declared in the PyMC model block.

        Returns a dictionary of data_name: dictionary of property-name:property description pairs. The return value is
        used by the ``_print_data_requirements`` method, to print a message telling users how to define the necessary
        data for the model. Each dictionary should have the following key-value pairs:

        * key: "shape", value: a tuple of integers
        * key: "dims", value: tuple of strings
        """
        return {
            "exogenous_data": {
                "shape": (None, 1),
                # Same dim name as the one used for beta_exog in param_dims.
                "dims": (TIME_DIM, "exog_state"),
            }
        }

    @property
    def param_names(self):
        """Names of the parameters that must be defined in the PyMC model block."""
        return ["x0", "P0", "ar_params", "sigma_x", "beta_exog"]

    @property
    def state_names(self):
        """Labels for the four hidden states: three lags of the data plus the regression coefficient."""
        return ["L1.data", "L2.data", "L3.data", "beta"]

    @property
    def shock_names(self):
        """Labels for the shocks."""
        # There is one shock, called the "innovations" in the literature, so i'll go with that
        return ["innovations"]

    @property
    def observed_states(self):
        """Labels for the observed series."""
        return ["data"]

    @property
    def param_dims(self):
        """Map each parameter name to the names of the coords that label its axes."""
        # There are special standardized names to use here. You can import them from
        # pymc_experimental.statespace.utils.constants
        # Not the best system. Something to improve on in the future.
        return {
            # x0 is registered with shape (3,): it covers only the lag states, so it
            # cannot use ALL_STATE_DIM (which also includes the regression coefficient).
            "x0": ("lag_state",),
            "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
            "ar_params": ("ar_lags",),
            "sigma_x": (SHOCK_DIM,),
            "beta_exog": ("exog_state",),
        }

    @property
    def coords(self):
        """Default statespace coords plus the coords for the dims unique to this model."""
        # make_default_coords puts coords on the statespace matrices (x0, P0, c, d, T, Z, R, H, Q)
        # and on the different filter outputs, so only model-specific dims are added here.
        coords = make_default_coords(self)
        coords.update(
            {
                "ar_lags": [1, 2, 3],
                "lag_state": self.state_names[:3],
                # "exog_state" is the dim name param_dims and data_info refer to.
                "exog_state": np.arange(self.k_exog),
                # Retained so existing users of these keys keep working.
                "beta_exog_dims": np.arange(self.k_exog),
                "n_obs": np.arange(self.n_obs),
            }
        )
        return coords

    @property
    def param_info(self):
        """Describe every parameter: a dict of name -> {"shape", "constraints", "dims"}."""
        info = {
            "x0": {
                "shape": (self.k_states - 1,),
                "constraints": "None",
            },
            "P0": {
                "shape": (self.k_states, self.k_states),
                "constraints": "Positive Semi-definite",
            },
            "sigma_x": {
                "shape": (self.k_posdef,),
                "constraints": "Positive",
            },
            "ar_params": {
                "shape": (3,),
                "constraints": "None",
            },
            "beta_exog": {
                "shape": (self.k_exog,),
                "constraints": "None",
            },
        }

        # Lazy way to add the dims without making any typos
        for name in self.param_names:
            info[name]["dims"] = self.param_dims[name]

        return info
# Instantiate the state space model first: its coords seed the PyMC model below.
ar3 = AutoRegressiveThree(exog)

with pm.Model(coords=ar3.coords) as pymc_mod:
    # Every name listed in ar3.data_names / ar3.param_names is declared here.
    exogenous_data = pm.Data("exogenous_data", exog, dims=['time', 'exog_state'])
    beta_exog = pm.Normal("beta_exog", 0, 1, dims=["exog_state"])

    # The initial lag states and initial covariance are fixed rather than sampled.
    x0 = pm.Deterministic(
        "x0",
        pt.as_tensor([0., 0., 0.]),
    )
    P0 = pm.Deterministic("P0", pt.eye(4) * 10., dims=["state", "state_aux"])

    ar_params = pm.Normal("ar_params", sigma=0.25, dims=["ar_lags"])
    sigma_x = pm.Exponential("sigma_x", 1, dims=["shock"])

    ar3.build_statespace_graph(data=data, mode="JAX")

    # NOTE(review): the traceback reported for this call fails inside
    # pymc.sampling.jax._replace_shared_variables, where every shared variable is
    # swapped for a constant before Scan.make_node rejects one of them as
    # incompatible with an inner variable typed Matrix(float64, shape=(?, ?)).
    # Presumably the constant's static shape (e.g. a length-1 axis) puts it in a
    # different type class than the symbolic placeholder -- TODO confirm.
    idata = pm.sample(nuts_sampler="numpyro")
And the full error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[573], line 17
13 sigma_x = pm.Exponential("sigma_x", 1, dims=["shock"])
16 ar3.build_statespace_graph(data=data, mode="JAX")
---> 17 idata = pm.sample(nuts_sampler="numpyro")
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/mcmc.py:741, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
736 raise ValueError(
737 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
738 )
740 with joined_blas_limiter():
--> 741 return _sample_external_nuts(
742 sampler=nuts_sampler,
743 draws=draws,
744 tune=tune,
745 chains=chains,
746 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
747 random_seed=random_seed,
748 initvals=initvals,
749 model=model,
750 var_names=var_names,
751 progressbar=progressbar,
752 idata_kwargs=idata_kwargs,
753 compute_convergence_checks=compute_convergence_checks,
754 nuts_sampler_kwargs=nuts_sampler_kwargs,
755 **kwargs,
756 )
758 if isinstance(step, list):
759 step = CompoundStep(step)
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/mcmc.py:364, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
361 elif sampler in ("numpyro", "blackjax"):
362 import pymc.sampling.jax as pymc_jax
--> 364 idata = pymc_jax.sample_jax_nuts(
365 draws=draws,
366 tune=tune,
367 chains=chains,
368 target_accept=target_accept,
369 random_seed=random_seed,
370 initvals=initvals,
371 model=model,
372 var_names=var_names,
373 progressbar=progressbar,
374 nuts_sampler=sampler,
375 idata_kwargs=idata_kwargs,
376 compute_convergence_checks=compute_convergence_checks,
377 **nuts_sampler_kwargs,
378 )
379 return idata
381 else:
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:611, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
608 raise ValueError(f"{nuts_sampler=} not recognized")
610 tic1 = datetime.now()
--> 611 raw_mcmc_samples, sample_stats, library = sampler_fn(
612 model=model,
613 target_accept=target_accept,
614 tune=tune,
615 draws=draws,
616 chains=chains,
617 chain_method=chain_method,
618 progressbar=progressbar,
619 random_seed=random_seed,
620 initial_points=initial_points,
621 nuts_kwargs=nuts_kwargs,
622 )
623 tic2 = datetime.now()
625 if idata_kwargs is None:
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:425, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
421 import numpyro
423 from numpyro.infer import MCMC, NUTS
--> 425 logp_fn = get_jaxified_logp(model, negative_logp=False)
427 nuts_kwargs.setdefault("adapt_step_size", True)
428 nuts_kwargs.setdefault("adapt_mass_matrix", True)
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:151, in get_jaxified_logp(model, negative_logp)
149 if not negative_logp:
150 model_logp = -model_logp
--> 151 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
153 def logp_fn_wrap(x):
154 return logp_fn(*x)[0]
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:126, in get_jaxified_graph(inputs, outputs)
121 def get_jaxified_graph(
122 inputs: list[TensorVariable] | None = None,
123 outputs: list[TensorVariable] | None = None,
124 ) -> list[TensorVariable]:
125 """Compile a PyTensor graph into an optimized JAX function."""
--> 126 graph = _replace_shared_variables(outputs) if outputs is not None else None
128 fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
129 # We need to add a Supervisor to the fgraph to be able to run the
130 # JAX sequential optimizer without warnings. We made sure there
131 # are no mutable input variables, so we only need to check for
132 # "destroyers". This should be automatically handled by PyTensor
133 # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pymc/sampling/jax.py:117, in _replace_shared_variables(graph)
110 raise ValueError(
111 "Graph contains shared variables with default_update which cannot "
112 "be safely replaced."
113 )
115 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 117 new_graph = clone_replace(graph, replace=replacements)
118 return new_graph
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/replace.py:85, in clone_replace(output, replace, **rebuild_kwds)
82 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
84 # TODO Explain why we call it twice ?!
---> 85 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
87 return outs
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:313, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
311 for v in outputs:
312 if isinstance(v, Variable):
--> 313 cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
314 cloned_outputs.append(cloned_v)
315 elif isinstance(v, Out):
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
[... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 189 (2 times)]
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:190, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
188 for i in owner.inputs:
189 clone_v_get_shared_updates(i, copy_inputs_over)
--> 190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
197 elif isinstance(v, SharedVariable):
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/basic.py:1296, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
1292 new_op: Op | None = cast(Optional["Op"], clone_d.get(node.op))
1294 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1296 new_node = node.clone_with_new_inputs(
1297 cloned_inputs,
1298 # Only clone inner-graph `Op`s when there isn't a cached clone (and
1299 # when `clone_inner_graphs` is enabled)
1300 clone_inner_graph=clone_inner_graphs if new_op is None else False,
1301 **kwargs,
1302 )
1304 if new_op:
1305 # If we didn't clone the inner-graph `Op` above, because
1306 # there was a cached version, set the cloned `Apply` to use
1307 # the cached clone `Op`
1308 new_node.op = new_op
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/graph/basic.py:297, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
294 if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
295 new_op = new_op.clone() # type: ignore
--> 297 new_node = new_op.make_node(*new_inputs)
298 new_node.tag = copy(self.tag).__update__(new_node.tag)
299 else:
File /opt/miniconda3/envs/pymc/lib/python3.12/site-packages/pytensor/scan/op.py:1196, in Scan.make_node(self, *inputs)
1194 new_inputs.append(outer_nonseq)
1195 if not outer_nonseq.type.in_same_class(inner_nonseq.type):
-> 1196 raise ValueError(
1197 f"Argument {outer_nonseq} given to the scan node is not"
1198 f" compatible with its corresponding loop function variable {inner_nonseq}"
1199 )
1201 for outer_nitsot in self.outer_nitsot(inputs):
1202 # For every nit_sot input we get as input a int/uint that
1203 # depicts the size in memory for that sequence. This feature is
1204 # used by truncated BPTT and by scan space optimization
1205 if (
1206 str(outer_nitsot.type.dtype) not in integer_dtypes
1207 or outer_nitsot.ndim != 0
1208 ):
ValueError: Argument [[ 0.69903 ... 32837271]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>