# Using aesara.scan to compute the linear predictor involving a B-spline matrix

Hi there,

This isn’t my first question involving the use of aesara.scan (sorry!). I haven’t grasped how to write them just yet.

I want to build upon model 6 in this fantastic blog post by @jhrcook, with an end goal of incorporating spatial dependence between the various time sequences.

For now, I want to extend the model to many time sequences (each with only a few observations, which I hope keeps the computation manageable). The model is:

import aesara.tensor as at
import numpy as np
import pymc as pm

# B (spline basis matrix), B_dim, n_sequences and sequence_index are defined elsewhere

# simulated observed data
E = np.random.uniform(low=2.0, high=8.0, size=(45108,))
y = np.random.poisson(6, size=(45108,))

# dimension information
coords = {
    "spline_dim": np.arange(11),
    "aux_dim": np.arange(1),
    "num_sequence": np.arange(2148),
    "n_obs": np.arange(45108),
}

# model
with pm.Model(coords=coords) as spline_model:
    sd = pm.Gamma.dist(2, 0.5, shape=B_dim)
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", eta=2, n=B_dim, sd_dist=sd, compute_corr=True
    )
    cov = pm.Deterministic("cov", chol.dot(chol.T))

    mu_w = pm.Normal("mu_w", 0, 1, dims=("spline_dim", "aux_dim"))
    delta_w = pm.Normal("delta_w", 0, 1, dims=("spline_dim", "num_sequence"))
    w = pm.Deterministic("w", mu_w + at.dot(chol, delta_w))

    # one graph node per sequence: this Python loop is the slow part
    _mu = []
    for i in range(n_sequences):
        _mu.append(pm.math.dot(B[sequence_index == i, :], w[:, i]).reshape((-1, 1)))

    a = pm.Normal("a", 0, 10)

    mu = pm.Deterministic("mu", a + at.vertical_stack(*_mu).squeeze() + at.log(E), dims="n_obs")
    y_ = pm.Poisson("y_", at.exp(mu), observed=y, dims="n_obs")


Here is the plate diagram for the model:

Currently, the model doesn’t compile. By that I mean it doesn’t raise an error, but after ~30 minutes of running it has yet to begin sampling.

I think the reason for this is that

_mu = []
for i in range(n_sequences):
    _mu.append(pm.math.dot(B[sequence_index == i, :], w[:, i]).reshape((-1, 1)))


should be replaced with an aesara.scan. Currently, this loop runs over 2148 different time sequences. Each iteration indexes into the matrix $B\in\mathbb{R}^{45108\times 11}$ to extract $B_i\in\mathbb{R}^{21\times 11}$ and then takes the dot product with $\mathbf{w}_i\in\mathbb{R}^{11}$ to return $\boldsymbol{\mu}_i\in\mathbb{R}^{21}$. If a single scan statement could return $\boldsymbol{\mu}\in\mathbb{R}^{45108}$, then maybe the model would compile much faster?

Thanks,
Conor

You definitely don’t want to loop over creation of aesara variables, so your instinct to try a scan is right.

You can re-create your loop by using at.arange(n_sequences) as the sequence input to scan. If you do this, you can treat B, w, and sequence_index as non-sequences. It ends up looking very close to what you have:

# at.eq builds the elementwise mask; `==` is not overloaded on aesara tensors
result, updates = aesara.scan(
    lambda i, A, x, idx: A[at.eq(idx, i), :].dot(x[:, i]),
    sequences=at.arange(n_sequences),
    non_sequences=[B, w, sequence_index],
)


Remember that the order of inputs to scan goes sequences, then recursive outputs, then non-sequences. That’s why the inputs to the lambda need to be in that order. I guess you could omit the non-sequences altogether and just use global variables inside the lambda, but I think it looks sloppy.

result will stack the per-step outputs along a new leading axis, so if every sequence has the same length it has shape (n_sequences, seq_len), and you can finagle that to get your mu vector. I guess you would flatten it to get a single vector of length n_obs, if I understand right.
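To make the shapes concrete, here is a small NumPy sketch of what that scan computes step by step. It is only an emulation (no aesara involved). The toy sizes, the random B and w, and the assumption that observations are sorted by sequence with a common length T are mine, not taken from the original model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sequences, T, B_dim = 4, 3, 5            # toy sizes (hypothetical)
n_obs = n_sequences * T

B = rng.normal(size=(n_obs, B_dim))        # stacked spline basis rows
w = rng.normal(size=(B_dim, n_sequences))  # one weight column per sequence
sequence_index = np.repeat(np.arange(n_sequences), T)  # sorted by sequence

# One "scan step" per sequence: pick that sequence's rows, multiply by its weights.
steps = [B[sequence_index == i, :] @ w[:, i] for i in range(n_sequences)]

# scan stacks the step outputs along a new leading axis.
result = np.stack(steps)                   # shape (n_sequences, T)

# Flattening row by row gives one value per observation. The order matches
# the rows of B only because the observations are sorted by sequence.
mu = result.reshape(-1)                    # shape (n_obs,)
```

If the sequences had different lengths, the per-step outputs could not be stacked into one rectangular array, which is why the equal-length assumption matters here.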


Also, since in your loop you are doing a matrix multiplication that does not depend on output from the previous step, you should be able to rewrite it with a tensor operation instead.
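As a rough illustration of that idea, the sketch below gathers each observation’s weight vector with integer indexing and takes a row-wise dot product, again in plain NumPy with made-up toy data. I would expect the same expression, `(B * w[:, sequence_index].T).sum(axis=1)`, to carry over to aesara tensors, but that is an assumption I have not tested inside the model.

```python
import numpy as np

rng = np.random.default_rng(1)
n_sequences, T, B_dim = 4, 3, 5            # toy sizes (hypothetical)
n_obs = n_sequences * T

B = rng.normal(size=(n_obs, B_dim))
w = rng.normal(size=(B_dim, n_sequences))
# Observations in arbitrary order: this approach does not need sorted data.
sequence_index = rng.permutation(np.repeat(np.arange(n_sequences), T))

# Gather the weight vector belonging to each observation's sequence.
w_per_obs = w[:, sequence_index].T         # shape (n_obs, B_dim)

# Row-wise dot product: mu[n] = B[n, :] . w[:, sequence_index[n]]
mu_vectorized = (B * w_per_obs).sum(axis=1)

# Reference computed one observation at a time.
mu_loop = np.array([B[n] @ w[:, sequence_index[n]] for n in range(n_obs)])
```

Unlike the loop or the scan, this form needs neither equal-length sequences nor a particular ordering of the rows.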


Yes, you can also block-diagonalize both B and w and do a matrix multiplication. B would be transformed into a (T * n_sequences, B_dim * n_sequences) matrix, while w would be (B_dim * n_sequences, n_sequences). The structure is such that there will be only a single non-zero entry in each row of B @ w, so you can sum away the n_sequences dimension on the result to get the answer.

I found this way was actually slower than the scan for large B matrices, unless one uses sparse matrices. Here’s how I transformed everything:

from scipy import sparse
from aesara import sparse as asm

B_prime = sparse.block_diag([B[i * T:(i+1) * T, :] for i in range(n_sequences)], format='csr')
B_at = asm.as_sparse_or_tensor_variable(B_prime)

w_prime_at = at.linalg.kron(at.eye(n_sequences), at.ones((B_dim, 1))) * at.concatenate([w] * n_sequences)

mu = asm.dot(B_at, w_prime_at).sum(axis=1)


This gives the same answer as the scan. I didn’t do any careful benchmarking, but it seemed OK in my .eval() tests. I also have no idea if pm.Data is compatible with aesara sparse matrices, which might be important.
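For anyone who wants to see the structure on a small case, here is a dense NumPy version of the same transformation. The toy sizes and random inputs are made up, and it uses dense arrays purely for readability, so it says nothing about the speed of the sparse version.

```python
import numpy as np

rng = np.random.default_rng(2)
n_sequences, T, B_dim = 3, 2, 4            # toy sizes (hypothetical)

B = rng.normal(size=(n_sequences * T, B_dim))
w = rng.normal(size=(B_dim, n_sequences))

# Block-diagonal B: block i holds the T rows of sequence i.
B_prime = np.zeros((n_sequences * T, n_sequences * B_dim))
for i in range(n_sequences):
    rows = slice(i * T, (i + 1) * T)
    cols = slice(i * B_dim, (i + 1) * B_dim)
    B_prime[rows, cols] = B[rows, :]

# w_prime: column j holds w[:, j] in block j and zeros elsewhere.
mask = np.kron(np.eye(n_sequences), np.ones((B_dim, 1)))  # (n_sequences * B_dim, n_sequences)
w_prime = mask * np.concatenate([w] * n_sequences)

product = B_prime @ w_prime                # (n_sequences * T, n_sequences)
mu = product.sum(axis=1)                   # sum away the n_sequences dimension

# Reference: the original per-sequence loop.
mu_loop = np.concatenate(
    [B[i * T:(i + 1) * T] @ w[:, i] for i in range(n_sequences)]
)
```

Each row of `product` has its only non-zero value in the column of its own sequence, which is why the row sum recovers mu.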


Hi @jessegrabowski, thanks for taking the time to write a great reply! I learn a lot about aesara from your answers.

Plugging the sparse matrices code that you provided into the model and fitting the model to a subset of data (e.g., 20 time sequences), everything works great! But when running the model on the full data set (e.g., ~2000 time sequences), I get the following error message. Is it something to do with memory constraints? Have you seen something like this before?

CompileError                              Traceback (most recent call last)
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:131, in COp.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
130 try:
--> 131     return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
132 except (NotImplementedError, MethodNotDefined):
133     # We requested the c code, so don't catch the error.

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:96, in COp.make_c_thunk(self, node, storage_map, compute_map, no_recycling)
95         raise NotImplementedError("float16")
---> 96 outputs = cl.make_thunk(
97     input_storage=node_input_storage, output_storage=node_output_storage
98 )
99 thunk, node_input_filters, node_output_filters = outputs

-> 1198 cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
1199     input_storage, output_storage, storage_map
1200 )

1132 output_storage = tuple(output_storage)
-> 1133 thunk, module = self.cthunk_factory(
1134     error_storage,
1135     input_storage,
1136     output_storage,
1137     storage_map,
1138 )
1139 return (
1140     thunk,
1141     module,
(...)
1150     error_storage,
1151 )

1628         node.op.prepare_node(node, storage_map, None, "c")
-> 1629     module = get_module_cache().module_from_key(key=key, lnk=self)
1631 vars = self.inputs + self.outputs + self.orphans

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:1223, in ModuleCache.module_from_key(self, key, lnk)
1222 location = dlimport_workdir(self.dirname)
-> 1223 module = lnk.compile_cmodule(location)
1224 name = module.__file__

1537     _logger.debug(f"LOCATION {location}")
-> 1538     module = c_compiler.compile_str(
1539         module_name=mod.code_hash,
1540         src_code=src_code,
1541         location=location,
1543         lib_dirs=self.lib_dirs(),
1544         libs=libs,
1545         preargs=preargs,
1546     )
1547 except Exception as e:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:2636, in GCC_compiler.compile_str(module_name, src_code, location, include_dirs, lib_dirs, libs, preargs, py_module, hide_symbols)
2632     # We replace '\n' by '. ' in the error message because when Python
2633     # prints the exception, having '\n' in the text makes it more
2635     # compile_stderr = compile_stderr.replace("\n", ". ")
-> 2636     raise CompileError(
2637         f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
2638     )
2639 elif config.cmodule__compilation_warning and compile_stderr:
2640     # Print errors just below the command line.

CompileError: Compilation failed (return status=1):
/Users/conor/miniconda3/envs/pymc_non_dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -Wno-c++11-narrowing -fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/numpy/core/include -I/Users/conor/miniconda3/envs/pymc_non_dev/include/python3.10 -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/c_code -L/Users/conor/miniconda3/envs/pymc_non_dev/lib -fvisibility=hidden -o /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/m5d0ee6bbad3216d425b6f97b97d8e0d98406727e7556194d542dc09b367c4d92.so /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: fatal error: bracket nesting level exceeded maximum of 256
if (!PyErr_Occurred()) {
^
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: note: use -fbracket-depth=N to increase maximum nesting level
1 error generated.

During handling of the above exception, another exception occurred:

CompileError                              Traceback (most recent call last)
/Users/conor/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Documents/AustralianCancerAtlasPyMC/data-wrangling.ipynb Cell 27' in <cell line: 1>()
1 with spline_temporal_model:
----> 2     spline_fit = pm.fit(n=100000)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:744, in fit(n, method, model, random_seed, start, inf_kwargs, **kwargs)
742 else:
743     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 744 return inference.fit(n, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:138, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
136     callbacks = []
137 score = self._maybe_score(score)
--> 138 step_func = self.objective.step_function(score=score, **kwargs)
139 if progressbar:
140     progress = progress_bar(range(n), display=progressbar)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/configparser.py:47, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
44 @wraps(f)
45 def res(*args, **kwargs):
46     with self:
---> 47         return f(*args, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/opvi.py:367, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, fn_kwargs)
356     obj_n_mc=obj_n_mc,
357     tf_n_mc=tf_n_mc,
(...)
365 )
366 if score:
368 else:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/aesaraf.py:1034, in compile_pymc(inputs, outputs, random_seed, mode, **kwargs)
1032 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
-> 1034 aesara_function = aesara.function(
1035     inputs,
1036     outputs,
1038     mode=mode,
1039     **kwargs,
1040 )
1041 return aesara_function

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
311     fn = orig_function(
312         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
313     )
314 else:
315     # note: pfunc will also call orig_function -- orig_function is
316     #      a choke point that all compilation must pass through
--> 317     fn = pfunc(
318         params=inputs,
319         outputs=outputs,
320         mode=mode,
322         givens=givens,
324         accept_inplace=accept_inplace,
325         name=name,
326         rebuild_strict=rebuild_strict,
327         allow_input_downcast=allow_input_downcast,
328         on_unused_input=on_unused_input,
329         profile=profile,
330         output_keys=output_keys,
331     )
332 return fn

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
360     profile = ProfileStats(message=profile)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363     params,
364     outputs,
(...)
371     fgraph=fgraph,
372 )
--> 374 return orig_function(
375     inputs,
376     cloned_outputs,
377     mode,
378     accept_inplace=accept_inplace,
379     name=name,
380     profile=profile,
381     on_unused_input=on_unused_input,
382     output_keys=output_keys,
383     fgraph=fgraph,
384 )

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:1763, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1751     m = Maker(
1752         inputs,
1753         outputs,
(...)
1760         fgraph=fgraph,
1761     )
1762     with config.change_flags(compute_test_value="off"):
-> 1763         fn = m.create(defaults)
1764 finally:
1765     t2 = time.time()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:1656, in FunctionMaker.create(self, input_storage, trustme, storage_map)
1655 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1656     _fn, _i, _o = self.linker.make_thunk(
1657         input_storage=input_storage_lists, storage_map=storage_map
1658     )

247 def make_thunk(
248     self,
249     input_storage: Optional["InputStorageType"] = None,
(...)
252     **kwargs,
253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254     return self.make_all(
255         input_storage=input_storage,
256         output_storage=output_storage,
257         storage_map=storage_map,
258     )[:3]

1252             thunks[-1].lazy = False
1253     except Exception:
-> 1254         raise_with_op(fgraph, node)
1256 t1 = time.time()
1258 if self.profile:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529     warnings.warn(
530         f"{exc_type} error does not allow us to add an extra error message"
531     )
532     # Some exception need extra parameter in inputs. So forget the
533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

1240 thunk_start = time.time()
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start
1248 if not hasattr(thunks[-1], "lazy"):
1249     # We don't want all ops maker to think about lazy Ops.
1250     # So if they didn't specify that its lazy or not, it isn't.
1251     # If this member isn't present, it will crash later.

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:131, in COp.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
127 self.prepare_node(
128     node, storage_map=storage_map, compute_map=compute_map, impl="c"
129 )
130 try:
--> 131     return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
132 except (NotImplementedError, MethodNotDefined):
133     # We requested the c code, so don't catch the error.
134     if impl == "c":

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:96, in COp.make_c_thunk(self, node, storage_map, compute_map, no_recycling)
94         print(f"Disabling C code for {self} due to unsupported float16")
95         raise NotImplementedError("float16")
---> 96 outputs = cl.make_thunk(
97     input_storage=node_input_storage, output_storage=node_output_storage
98 )
99 thunk, node_input_filters, node_output_filters = outputs
101 @is_cthunk_wrapper_type
102 def rval():

1170 """
1171 Compiles this linker's fgraph and returns a function to perform the
1172 computations, as well as lists of storage cells for both the inputs
(...)
1195   first_output = ostor[0].data
1196 """
-> 1198 cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
1199     input_storage, output_storage, storage_map
1200 )
1203 res.nodes = self.node_order

1131 input_storage = tuple(input_storage)
1132 output_storage = tuple(output_storage)
-> 1133 thunk, module = self.cthunk_factory(
1134     error_storage,
1135     input_storage,
1136     output_storage,
1137     storage_map,
1138 )
1139 return (
1140     thunk,
1141     module,
(...)
1150     error_storage,
1151 )

1627     for node in self.node_order:
1628         node.op.prepare_node(node, storage_map, None, "c")
-> 1629     module = get_module_cache().module_from_key(key=key, lnk=self)
1631 vars = self.inputs + self.outputs + self.orphans
1632 # List of indices that should be ignored when passing the arguments
1633 # (basically, everything that the previous call to uniq eliminated)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:1223, in ModuleCache.module_from_key(self, key, lnk)
1221 try:
1222     location = dlimport_workdir(self.dirname)
-> 1223     module = lnk.compile_cmodule(location)
1224     name = module.__file__
1225     assert name.startswith(location)

1536 try:
1537     _logger.debug(f"LOCATION {location}")
-> 1538     module = c_compiler.compile_str(
1539         module_name=mod.code_hash,
1540         src_code=src_code,
1541         location=location,
1543         lib_dirs=self.lib_dirs(),
1544         libs=libs,
1545         preargs=preargs,
1546     )
1547 except Exception as e:
1548     e.args += (str(self.fgraph),)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:2636, in GCC_compiler.compile_str(module_name, src_code, location, include_dirs, lib_dirs, libs, preargs, py_module, hide_symbols)
2628                 print(
2629                     "Check if package python-dev or python-devel is installed."
2630                 )
2632     # We replace '\n' by '. ' in the error message because when Python
2633     # prints the exception, having '\n' in the text makes it more
2635     # compile_stderr = compile_stderr.replace("\n", ". ")
-> 2636     raise CompileError(
2637         f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
2638     )
2639 elif config.cmodule__compilation_warning and compile_stderr:
2640     # Print errors just below the command line.
2641     print(compile_stderr)

CompileError: Compilation failed (return status=1):
/Users/conor/miniconda3/envs/pymc_non_dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -Wno-c++11-narrowing -fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/numpy/core/include -I/Users/conor/miniconda3/envs/pymc_non_dev/include/python3.10 -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/c_code -L/Users/conor/miniconda3/envs/pymc_non_dev/lib -fvisibility=hidden -o /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/m5d0ee6bbad3216d425b6f97b97d8e0d98406727e7556194d542dc09b367c4d92.so /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: fatal error: bracket nesting level exceeded maximum of 256
if (!PyErr_Occurred()) {
^
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: note: use -fbracket-depth=N to increase maximum nesting level
1 error generated.

Apply node that caused the error: Split{2148}(Elemwise{Mul}[(0, 0)].0, TensorConstant{0}, TensorConstant{(2148,) of 11})
Toposort index: 135
Inputs types: [TensorType(float64, (None, None)), TensorType(int8, ()), TensorType(int64, (2148,))]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
term = access_term_cache(node)[idx]
File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in access_term_cache
File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in <listcomp>
term = access_term_cache(node)[idx]
File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in access_term_cache
File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in <listcomp>
term = access_term_cache(node)[idx]
File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1213, in access_term_cache

HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
HINT: Use the Aesara flag exception_verbosity=high for a debug print-out and storage map footprint of this Apply node.


Could be. I’ve never seen anyone use the sparse module of aesara in a PyMC model to be honest, so I think you’re on the “here be dragons” part of the map. It looks like it’s the C compiler that’s complaining, so that’s not a good sign. @ricardoV94 might be able to say more. The bracket nesting limit seems to be the problem, so if you were hell-bent on getting this approach to work, you might be able to add the -fbracket-depth=N argument to the compiler flags in your .aesararc file. But I know literally nothing about these compiler settings.
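If you want to try it without editing a config file, one option is to set the flag through the AESARA_FLAGS environment variable before importing aesara. This is an untested sketch: the depth of 1024 is an arbitrary guess, and I am assuming the `gcc__cxxflags` config option is the right place for extra compiler arguments. I believe the .aesararc equivalent would be a `cxxflags` entry under a `[gcc]` section, but treat that as an assumption too.

```python
import os

# Hypothetical value: the compiler note only says "-fbracket-depth=N".
bracket_depth = 1024
flag = f"gcc__cxxflags=-fbracket-depth={bracket_depth}"

# AESARA_FLAGS is a comma-separated list, so append to whatever is already set.
existing = os.environ.get("AESARA_FLAGS", "")
os.environ["AESARA_FLAGS"] = f"{existing},{flag}" if existing else flag

# import aesara  # import only after the variable is set, so the flag is picked up
```

Even if this gets past the compiler, a Split node with 2148 outputs suggests the generated C code is very large, so compilation may still be slow.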

Did you try out the scan as well? That might be a bit more gentle (and better studied as a component of PyMC models).


Yes, I have tried but haven’t had any luck. I don’t understand what aesara.scan returns within a pm.Model context.
Here is how I added the scan into the model:

result, _ = aesara.scan(lambda i, A, x, idx: A[np.array(idx==i), :].dot(x[:, i]),
                        sequences=at.arange(num_sequence),
                        non_sequences=[B, w, sequence_index])
mu_stack = at.stack(result[-1])


Here is the error message that I got:

ValueError                                Traceback (most recent call last)
File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]

During handling of the above exception, another exception occurred:

InnerFunctionError                        Traceback (most recent call last)
File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1606, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
1605 try:
-> 1606     t_fn, n_steps = scan_perform_ext.perform(
1607         self.info.n_shared_outs,
1608         self.info.n_mit_mot_outs,
1609         self.info.n_seqs,
1610         self.info.n_mit_mot,
1611         self.info.n_mit_sot,
1612         self.info.n_sit_sot,
1613         self.info.n_nit_sot,
1614         self.info.as_while,
1615         cython_mintaps,
1616         cython_pos,
1617         cython_store_steps,
1618         self.info.mit_mot_in_slices
1619         + self.info.mit_sot_in_slices
1620         + self.info.sit_sot_in_slices,
1621         tap_array_len,
1622         cython_vector_seqs,
1623         cython_vector_outs,
1624         self.info.mit_mot_out_slices,
1625         cython_mitmots_preallocated,
1626         mit_mot_out_to_tap_idx,
1627         cython_outs_is_tensor,
1628         inner_input_storage,
1629         inner_output_storage,
1630         cython_destroy_map,
1631         inputs,
1632         outputs,
1633         outer_output_dtypes,
1634         outer_output_ndims,
1635         self.fn.vm,
1636     )
1637 except InnerFunctionError as exc:

File scan_perform.pyx:420, in aesara.scan.scan_perform.perform()

InnerFunctionError: (ValueError('Shape mismatch: A.shape[1] != x.shape[0]'), <traceback object at 0x28d615640>)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
973 try:
974     outputs = (
--> 975         self.vm()
976         if output_subset is None
977         else self.vm(output_subset=output_subset)
978     )
979 except Exception:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1678, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
1675 def rval(
1676     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
1677 ):
-> 1678     r = p(n, [x[0] for x in i], o)
1679     for o in node.outputs:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1645, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
1642 if hasattr(self.fn.vm, "position_of_error") and hasattr(
1643     self.fn.vm, "thunks"
1644 ):
-> 1645     raise_with_op(
1646         self.fn.maker.fgraph,
1647         self.fn.vm.nodes[self.fn.vm.position_of_error],
1648         self.fn.vm.thunks[self.fn.vm.position_of_error],
1649         exc_info=(exc_type, exc_value, exc_trace),
1650     )
1651 else:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
532     # Some exception need extra parameter in inputs. So forget the
533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]
Apply node that caused the error: CGemv{inplace}(AllocEmpty{dtype='float64'}.0, TensorConstant{1.0}, InplaceDimShuffle{x,0}.0, *2-<TensorType(float64, (None,))>, TensorConstant{0.0})
Toposort index: 4
Inputs types: [TensorType(float64, (1,)), TensorType(float64, ()), TensorType(float64, (1, None)), TensorType(float64, (None,)), TensorType(float64, ())]
Inputs shapes: [(1,), (), (1, 6), (0, 42, 6), ()]
Inputs strides: [(8,), (), (96, 16), (2016, 48, 8), ()]
Inputs values: [array([0.]), array(1.), 'not shown', array([], shape=(0, 42, 6), dtype=float64), array(0.)]
Outputs clients: [[InplaceDimShuffle{}(CGemv{inplace}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag exception_verbosity=high for a debug print-out and storage map footprint of this Apply node.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/Users/conor/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Documents/AustralianCancerAtlasPyMC/data-wrangling.ipynb Cell 33' in <cell line: 1>()
1 with spline_temporal_model_ALT:
----> 2     fit_model_ALT = pm.fit()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:744, in fit(n, method, model, random_seed, start, inf_kwargs, **kwargs)
742 else:
743     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 744 return inference.fit(n, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
142     progress = range(n)
143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
145 else:
146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:204, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
202 try:
203     for i in progress:
--> 204         e = step_func()
205         if np.isnan(e):
206             scores = scores[:i]

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:988, in Function.__call__(self, *args, **kwargs)
986     if hasattr(self.vm, "thunks"):
987         thunk = self.vm.thunks[self.vm.position_of_error]
--> 988     raise_with_op(
989         self.maker.fgraph,
990         node=self.vm.nodes[self.vm.position_of_error],
991         thunk=thunk,
992         storage_map=getattr(self.vm, "storage_map", None),
993     )
994 else:
995     # old-style linkers raise their own exceptions
996     raise

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529     warnings.warn(
530         f"{exc_type} error does not allow us to add an extra error message"
531     )
532     # Some exception need extra parameter in inputs. So forget the
533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
972 t0_fn = time.time()
973 try:
974     outputs = (
--> 975         self.vm()
976         if output_subset is None
977         else self.vm(output_subset=output_subset)
978     )
979 except Exception:
980     restore_defaults()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1678, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
1675 def rval(
1676     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
1677 ):
-> 1678     r = p(n, [x[0] for x in i], o)
1679     for o in node.outputs:
1680         compute_map[o][0] = True

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1645, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
1640 exc_trace = exc.args[1]
1642 if hasattr(self.fn.vm, "position_of_error") and hasattr(
1643     self.fn.vm, "thunks"
1644 ):
-> 1645     raise_with_op(
1646         self.fn.maker.fgraph,
1647         self.fn.vm.nodes[self.fn.vm.position_of_error],
1648         self.fn.vm.thunks[self.fn.vm.position_of_error],
1649         exc_info=(exc_type, exc_value, exc_trace),
1650     )
1651 else:
1652     raise exc_value.with_traceback(exc_trace)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529     warnings.warn(
530         f"{exc_type} error does not allow us to add an extra error message"
531     )
532     # Some exception need extra parameter in inputs. So forget the
533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]
Apply node that caused the error: CGemv{inplace}(AllocEmpty{dtype='float64'}.0, TensorConstant{1.0}, InplaceDimShuffle{x,0}.0, *2-<TensorType(float64, (None,))>, TensorConstant{0.0})
Toposort index: 4
Inputs types: [TensorType(float64, (1,)), TensorType(float64, ()), TensorType(float64, (1, None)), TensorType(float64, (None,)), TensorType(float64, ())]
Inputs shapes: [(1,), (), (1, 6), (0, 42, 6), ()]
Inputs strides: [(8,), (), (96, 16), (2016, 48, 8), ()]
Inputs values: [array([0.]), array(1.), 'not shown', array([], shape=(0, 42, 6), dtype=float64), array(0.)]
Outputs clients: [[InplaceDimShuffle{}(CGemv{inplace}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag exception_verbosity=high for a debug print-out and storage map footprint of this Apply node.
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{2}, TensorConstant{[0 1]}, TensorConstant{1}, Elemwise{Add}[(0, 1)].0, AdvancedSubtensor.0)
Toposort index: 84
Inputs types: [TensorType(int64, ()), TensorType(int64, (2,)), TensorType(int64, ()), TensorType(float64, (None, None)), TensorType(float64, (None,))]
Inputs shapes: [(), (2,), (), (6, 2), (0, 42, 6)]
Inputs strides: [(), (8,), (), (16, 8), (2016, 48, 8)]
Inputs values: [array(2), array([0, 1]), array(1), 'not shown', array([], shape=(0, 42, 6), dtype=float64)]
Outputs clients: [[Subtensor{int64}(for{cpu,scan_fn}.0, ScalarConstant{0})]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag exception_verbosity=high for a debug print-out and storage map footprint of this Apply node.


It’s saying that it can’t do the matrix multiplication A[idx == i, :] @ x[:, i] because the shapes don’t match up. Do all of your time series have the same length?
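One detail in the traceback may point somewhere else, though: the indexed input has shape (0, 42, 6) rather than something like (21, 6). The NumPy sketch below reproduces that shape under the assumption (a guess on my part) that `idx == i` collapsed to a single Python bool instead of an elementwise mask, which is what I would expect when `==` is applied to a symbolic tensor. It also includes a quick check of whether all sequences have the same length. The 42 x 6 matrix and the two sequences of 21 rows are toy values chosen to match the shapes in the error message.

```python
import numpy as np

B = np.ones((42, 6))                          # same shape as in the traceback
sequence_index = np.repeat(np.arange(2), 21)  # two sequences of 21 rows (assumed)

# A proper elementwise mask keeps the 21 rows of sequence 0.
good = B[sequence_index == 0, :]              # shape (21, 6)

# If the comparison collapses to one bool, np.array(...) gives a 0-d mask,
# and indexing with it adds an empty leading axis instead of selecting rows.
collapsed = np.array(False)
bad = B[collapsed, :]                         # shape (0, 42, 6)

# Separate check: do all sequences have the same number of observations?
lengths = np.bincount(sequence_index)
all_equal = np.unique(lengths).size == 1
```

If that guess is right, building the mask with at.eq(idx, i) inside the lambda, rather than `==` wrapped in np.array, would be the thing to try.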

I’ve been running into the same error whilst trying to use aesara.scan. Did you find the issue? (Was it just a shape mismatch as @jessegrabowski suggested?)

Thanks! David