Using aesara.scan to compute the linear predictor involving a B-spline matrix

Hi there,

This isn’t my first question involving the use of aesara.scan (sorry!). I haven’t grasped how to write them just yet.

I want to build upon model 6 in this fantastic blog post by @jhrcook, with an end goal of incorporating spatial dependence between the various time sequences.

For now, I want to extend the model to many time sequences, each with only a few observations (hopefully this helps from a computational point of view). The model is:

import numpy as np
import pymc as pm
import aesara.tensor as at

# simulated observed data
E = np.random.uniform(low=2.0, high=8.0, size=(45108,))
y = np.random.poisson(6, size=(45108,))

# dimension information 
coords = {"spline_dim": np.arange(11), 
          "aux_dim": np.arange(1), 
          "num_sequence": np.arange(2148), 
          "n_obs":np.arange(45108)}

# model 
with pm.Model(coords=coords) as spline_model:
    sd = pm.Gamma.dist(2, 0.5, shape=B_dim)
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", eta=2, n=B_dim, sd_dist=sd, compute_corr=True
    )
    cov = pm.Deterministic("cov", chol.dot(chol.T))

    mu_w = pm.Normal("mu_w", 0, 1, dims=("spline_dim", "aux_dim"))
    delta_w = pm.Normal("delta_w", 0, 1, dims=("spline_dim", "num_sequence"))
    w = pm.Deterministic("w", mu_w + at.dot(chol, delta_w))

    # one dot product per sequence; every pass of this Python loop adds new nodes to the graph
    _mu = []
    for i in range(n_sequences):
        _mu.append(pm.math.dot(B[sequence_index == i, :], w[:, i]).reshape((-1, 1)))

    a = pm.Normal("a", 0, 10)

    mu = pm.Deterministic("mu", a + at.vertical_stack(*_mu).squeeze() + at.log(E), dims="n_obs")
    y_ = pm.Poisson("y_", at.exp(mu), observed=y, dims="n_obs")

Here is the plate diagram for the model:

Currently, the model doesn’t compile. By not compiling, I mean that it doesn’t necessarily error, but after leaving it running for ~30 minutes, the model is yet to begin the sampling procedure.

I think the reason for this is that

    _mu = []
    for i in range(n_sequences):
        _mu.append(pm.math.dot(B[sequence_index == i, :], w[:, i]).reshape((-1, 1)))

should be replaced with an aesara.scan. Currently, this loop runs over 2148 different time sequences. Each iteration indexes into the matrix $B\in\mathbb{R}^{45108\times 11}$ to get $B_i\in\mathbb{R}^{21\times 11}$, then takes the dot product with $\mathbf{w}_i\in\mathbb{R}^{11}$ to return $\boldsymbol{\mu}_i\in\mathbb{R}^{21}$. If a single scan statement could return $\boldsymbol{\mu}\in\mathbb{R}^{45108}$, then maybe the model would compile much faster?

Thanks,
Conor

You definitely don’t want to loop over creation of aesara variables, so your instinct to try a scan is right.

You can re-create your loop by using at.arange(n_sequences) as the sequence input to scan. If you do this, you can treat B, w, and sequence_index as non-sequences. It ends up looking very close to what you have:

# at.eq builds an elementwise mask; == is not overloaded for symbolic tensors
result, updates = aesara.scan(lambda i, A, x, idx: A[at.eq(idx, i), :].dot(x[:, i]),
                              sequences=at.arange(n_sequences),
                              non_sequences=[B, w, sequence_index])

Remember that the order of inputs to scan goes sequences, then recursive outputs, then non-sequences. That’s why the inputs to the lambda need to be in that order. I guess you could omit the non-sequences altogether and just use global variables inside the lambda, but I think it looks sloppy.

result will be a tensor of shape (n_sequences, seq_len), with one row per step, so you can finagle that to get your mu vector. I guess you would flatten it to get a single vector of length n_obs, if I understand right.
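
As a rough NumPy analogue of what the scan collects (a sketch only, with small hypothetical sizes, and assuming every sequence has the same number of rows and the rows of B are sorted by sequence): each step yields a vector of length seq_len, a single-output scan hands these back stacked along a new leading axis, and flattening that gives the n_obs-long vector.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small sizes; the thread uses 2148 sequences of 21 rows and 11 basis functions.
n_sequences, seq_len, B_dim = 4, 3, 5
n_obs = n_sequences * seq_len

B = rng.normal(size=(n_obs, B_dim))                           # one basis row per observation
w = rng.normal(size=(B_dim, n_sequences))                     # one weight column per sequence
sequence_index = np.repeat(np.arange(n_sequences), seq_len)   # rows sorted by sequence

# What each step computes: the rows of B for sequence i times that sequence's weights.
steps = [B[sequence_index == i, :].dot(w[:, i]) for i in range(n_sequences)]

# Per-step outputs stacked along a new leading axis, then flattened to length n_obs.
result = np.stack(steps)
mu = result.reshape(-1)

print(result.shape, mu.shape)
```

If the rows of B were not sorted by sequence, the flattened vector would no longer line up with the observations, so that assumption matters.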

Also, since in your loop you are doing a matrix multiplication that does not depend on output from the previous step, you should be able to rewrite it with a tensor operation instead.
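
One way such a tensor operation could look, sketched in NumPy with small hypothetical sizes: pick the weight vector belonging to each observation's sequence, multiply it elementwise with that observation's basis row, and sum over the basis dimension. The same expression, (B * w[:, sequence_index].T).sum(axis=1), should in principle carry over to Aesara tensors, but that is an assumption and is not tested here.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical small sizes for illustration.
n_sequences, seq_len, B_dim = 4, 3, 5
n_obs = n_sequences * seq_len

B = rng.normal(size=(n_obs, B_dim))
w = rng.normal(size=(B_dim, n_sequences))
# Rows may appear in any order with this approach; shuffle them to illustrate.
sequence_index = rng.permutation(np.repeat(np.arange(n_sequences), seq_len))

# Reference: the loop, writing each sequence's result back to its own rows.
mu_loop = np.empty(n_obs)
for i in range(n_sequences):
    rows = sequence_index == i
    mu_loop[rows] = B[rows, :].dot(w[:, i])

# Vectorized: one weight vector per observation, elementwise product, sum over the basis axis.
w_per_obs = w[:, sequence_index].T        # shape (n_obs, B_dim)
mu_vec = (B * w_per_obs).sum(axis=1)      # shape (n_obs,)

print(np.allclose(mu_loop, mu_vec))
```

This version needs no loop and no assumption that the sequences are equally long or stored contiguously.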

Yes, you can also block-diagonalize both B and w and do a matrix multiplication. B would be transformed into a (T * n_sequences, B_dim * n_sequences) matrix, while w would be (B_dim * n_sequences, n_sequences). The structure is such that there will be only a single non-zero entry in each row after B @ w, so you can sum away the n_sequences dimension on the result to get the answer.

I found this way was actually slower than the scan for large B matrices, unless one uses sparse matrices. Here’s how I transformed everything:

from scipy import sparse
from aesara import sparse as asm

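# sparse block-diagonal B, shape (T * n_sequences, B_dim * n_sequences);
# assumes each sequence occupies T consecutive rows of B
B_prime = sparse.block_diag([B[i * T:(i+1) * T, :] for i in range(n_sequences)], format='csr')
B_at = asm.as_sparse_or_tensor_variable(B_prime)

# column j holds w[:, j] in rows j * B_dim to (j + 1) * B_dim, zeros elsewhere
w_prime_at = at.linalg.kron(at.eye(n_sequences), at.ones((B_dim, 1))) * at.concatenate([w] * n_sequences)

# sum over the n_sequences axis to recover mu
mu = asm.dot(B_at, w_prime_at).sum(axis=1)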

This gives the same answer as the scan. I didn’t do any careful benchmarking, but it seemed OK in my .eval() tests. I also have no idea if pm.Data is compatible with aesara sparse matrices, which might be important.
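
For anyone who wants to check the block-diagonal construction by hand, here is a small dense NumPy sketch (hypothetical sizes, no sparse matrices, and assuming each sequence occupies T consecutive rows of B). It also shows that the same numbers come out of a single matrix-vector product with the columns of w laid end to end, which is an alternative I am suggesting for illustration, not something from the code above.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical small sizes for illustration.
n_sequences, T, B_dim = 3, 4, 2

B = rng.normal(size=(T * n_sequences, B_dim))
w = rng.normal(size=(B_dim, n_sequences))

# Dense block-diagonal version of B: block i holds the rows of sequence i.
B_prime = np.zeros((T * n_sequences, B_dim * n_sequences))
for i in range(n_sequences):
    B_prime[i * T:(i + 1) * T, i * B_dim:(i + 1) * B_dim] = B[i * T:(i + 1) * T, :]

# Column j of w_prime holds w[:, j] in block j and zeros elsewhere.
w_prime = np.kron(np.eye(n_sequences), np.ones((B_dim, 1))) * np.concatenate([w] * n_sequences)

product = B_prime @ w_prime               # shape (T * n_sequences, n_sequences)
mu_block = product.sum(axis=1)            # one non-zero entry per row, so the sum picks it out

# Alternative: a single matrix-vector product with the columns of w stacked end to end.
mu_flat = B_prime @ w.T.reshape(-1)

# Reference: the per-sequence loop.
mu_loop = np.concatenate([B[i * T:(i + 1) * T, :] @ w[:, i] for i in range(n_sequences)])

print(np.allclose(mu_block, mu_loop), np.allclose(mu_flat, mu_loop))
```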

Hi @jessegrabowski, thanks for taking the time to write a great reply! I learn a lot about aesara from your answers.

Plugging the sparse matrices code that you provided into the model and fitting the model to a subset of data (e.g., 20 time sequences), everything works great! But when running the model on the full data set (e.g., ~2000 time sequences), I get the following error message. Is it something to do with memory constraints? Have you seen something like this before?

CompileError                              Traceback (most recent call last)
File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/vm.py:1245, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
   1241 # no-recycling is done at each VM.__call__ So there is
   1242 # no need to cause duplicate c code by passing
   1243 # no_recycling here.
   1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
   1246 )
   1247 linker_make_thunk_time[node] = time.time() - thunk_start

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:131, in COp.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
    130 try:
--> 131     return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
    132 except (NotImplementedError, MethodNotDefined):
    133     # We requested the c code, so don't catch the error.

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:96, in COp.make_c_thunk(self, node, storage_map, compute_map, no_recycling)
     95         raise NotImplementedError("float16")
---> 96 outputs = cl.make_thunk(
     97     input_storage=node_input_storage, output_storage=node_output_storage
     98 )
     99 thunk, node_input_filters, node_output_filters = outputs

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1198, in CLinker.make_thunk(self, input_storage, output_storage, storage_map)
   1197 init_tasks, tasks = self.get_init_tasks()
-> 1198 cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
   1199     input_storage, output_storage, storage_map
   1200 )
   1202 res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1133, in CLinker.__compile__(self, input_storage, output_storage, storage_map)
   1132 output_storage = tuple(output_storage)
-> 1133 thunk, module = self.cthunk_factory(
   1134     error_storage,
   1135     input_storage,
   1136     output_storage,
   1137     storage_map,
   1138 )
   1139 return (
   1140     thunk,
   1141     module,
   (...)
   1150     error_storage,
   1151 )

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1629, in CLinker.cthunk_factory(self, error_storage, in_storage, out_storage, storage_map)
   1628         node.op.prepare_node(node, storage_map, None, "c")
-> 1629     module = get_module_cache().module_from_key(key=key, lnk=self)
   1631 vars = self.inputs + self.outputs + self.orphans

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:1223, in ModuleCache.module_from_key(self, key, lnk)
   1222 location = dlimport_workdir(self.dirname)
-> 1223 module = lnk.compile_cmodule(location)
   1224 name = module.__file__

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1538, in CLinker.compile_cmodule(self, location)
   1537     _logger.debug(f"LOCATION {location}")
-> 1538     module = c_compiler.compile_str(
   1539         module_name=mod.code_hash,
   1540         src_code=src_code,
   1541         location=location,
   1542         include_dirs=self.header_dirs(),
   1543         lib_dirs=self.lib_dirs(),
   1544         libs=libs,
   1545         preargs=preargs,
   1546     )
   1547 except Exception as e:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:2636, in GCC_compiler.compile_str(module_name, src_code, location, include_dirs, lib_dirs, libs, preargs, py_module, hide_symbols)
   2632     # We replace '\n' by '. ' in the error message because when Python
   2633     # prints the exception, having '\n' in the text makes it more
   2634     # difficult to read.
   2635     # compile_stderr = compile_stderr.replace("\n", ". ")
-> 2636     raise CompileError(
   2637         f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
   2638     )
   2639 elif config.cmodule__compilation_warning and compile_stderr:
   2640     # Print errors just below the command line.

CompileError: Compilation failed (return status=1):
/Users/conor/miniconda3/envs/pymc_non_dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -Wno-c++11-narrowing -fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/numpy/core/include -I/Users/conor/miniconda3/envs/pymc_non_dev/include/python3.10 -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/c_code -L/Users/conor/miniconda3/envs/pymc_non_dev/lib -fvisibility=hidden -o /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/m5d0ee6bbad3216d425b6f97b97d8e0d98406727e7556194d542dc09b367c4d92.so /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: fatal error: bracket nesting level exceeded maximum of 256
        if (!PyErr_Occurred()) {
                               ^
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: note: use -fbracket-depth=N to increase maximum nesting level
1 error generated.


During handling of the above exception, another exception occurred:

CompileError                              Traceback (most recent call last)
/Users/conor/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Documents/AustralianCancerAtlasPyMC/data-wrangling.ipynb Cell 27' in <cell line: 1>()
      1 with spline_temporal_model:
----> 2     spline_fit = pm.fit(n=100000)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:744, in fit(n, method, model, random_seed, start, inf_kwargs, **kwargs)
    742 else:
    743     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 744 return inference.fit(n, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:138, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    136     callbacks = []
    137 score = self._maybe_score(score)
--> 138 step_func = self.objective.step_function(score=score, **kwargs)
    139 if progressbar:
    140     progress = progress_bar(range(n), display=progressbar)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/configparser.py:47, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     44 @wraps(f)
     45 def res(*args, **kwargs):
     46     with self:
---> 47         return f(*args, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/opvi.py:367, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, fn_kwargs)
    355 updates = self.updates(
    356     obj_n_mc=obj_n_mc,
    357     tf_n_mc=tf_n_mc,
   (...)
    364     total_grad_norm_constraint=total_grad_norm_constraint,
    365 )
    366 if score:
--> 367     step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs)
    368 else:
    369     step_fn = compile_pymc([], [], updates=updates, **fn_kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/aesaraf.py:1034, in compile_pymc(inputs, outputs, random_seed, mode, **kwargs)
   1032 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
   1033 mode = Mode(linker=mode.linker, optimizer=opt_qry)
-> 1034 aesara_function = aesara.function(
   1035     inputs,
   1036     outputs,
   1037     updates={**rng_updates, **kwargs.pop("updates", {})},
   1038     mode=mode,
   1039     **kwargs,
   1040 )
   1041 return aesara_function

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    311     fn = orig_function(
    312         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    313     )
    314 else:
    315     # note: pfunc will also call orig_function -- orig_function is
    316     #      a choke point that all compilation must pass through
--> 317     fn = pfunc(
    318         params=inputs,
    319         outputs=outputs,
    320         mode=mode,
    321         updates=updates,
    322         givens=givens,
    323         no_default_updates=no_default_updates,
    324         accept_inplace=accept_inplace,
    325         name=name,
    326         rebuild_strict=rebuild_strict,
    327         allow_input_downcast=allow_input_downcast,
    328         on_unused_input=on_unused_input,
    329         profile=profile,
    330         output_keys=output_keys,
    331     )
    332 return fn

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    360     profile = ProfileStats(message=profile)
    362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    363     params,
    364     outputs,
   (...)
    371     fgraph=fgraph,
    372 )
--> 374 return orig_function(
    375     inputs,
    376     cloned_outputs,
    377     mode,
    378     accept_inplace=accept_inplace,
    379     name=name,
    380     profile=profile,
    381     on_unused_input=on_unused_input,
    382     output_keys=output_keys,
    383     fgraph=fgraph,
    384 )

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:1763, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1751     m = Maker(
   1752         inputs,
   1753         outputs,
   (...)
   1760         fgraph=fgraph,
   1761     )
   1762     with config.change_flags(compute_test_value="off"):
-> 1763         fn = m.create(defaults)
   1764 finally:
   1765     t2 = time.time()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:1656, in FunctionMaker.create(self, input_storage, trustme, storage_map)
   1653 start_import_time = aesara.link.c.cmodule.import_time
   1655 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1656     _fn, _i, _o = self.linker.make_thunk(
   1657         input_storage=input_storage_lists, storage_map=storage_map
   1658     )
   1660 end_linker = time.time()
   1662 linker_time = end_linker - start_linker

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    247 def make_thunk(
    248     self,
    249     input_storage: Optional["InputStorageType"] = None,
   (...)
    252     **kwargs,
    253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254     return self.make_all(
    255         input_storage=input_storage,
    256         output_storage=output_storage,
    257         storage_map=storage_map,
    258     )[:3]

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/vm.py:1254, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
   1252             thunks[-1].lazy = False
   1253     except Exception:
-> 1254         raise_with_op(fgraph, node)
   1256 t1 = time.time()
   1258 if self.profile:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/vm.py:1245, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
   1240 thunk_start = time.time()
   1241 # no-recycling is done at each VM.__call__ So there is
   1242 # no need to cause duplicate c code by passing
   1243 # no_recycling here.
   1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
   1246 )
   1247 linker_make_thunk_time[node] = time.time() - thunk_start
   1248 if not hasattr(thunks[-1], "lazy"):
   1249     # We don't want all ops maker to think about lazy Ops.
   1250     # So if they didn't specify that its lazy or not, it isn't.
   1251     # If this member isn't present, it will crash later.

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:131, in COp.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
    127 self.prepare_node(
    128     node, storage_map=storage_map, compute_map=compute_map, impl="c"
    129 )
    130 try:
--> 131     return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
    132 except (NotImplementedError, MethodNotDefined):
    133     # We requested the c code, so don't catch the error.
    134     if impl == "c":

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/op.py:96, in COp.make_c_thunk(self, node, storage_map, compute_map, no_recycling)
     94         print(f"Disabling C code for {self} due to unsupported float16")
     95         raise NotImplementedError("float16")
---> 96 outputs = cl.make_thunk(
     97     input_storage=node_input_storage, output_storage=node_output_storage
     98 )
     99 thunk, node_input_filters, node_output_filters = outputs
    101 @is_cthunk_wrapper_type
    102 def rval():

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1198, in CLinker.make_thunk(self, input_storage, output_storage, storage_map)
   1170 """
   1171 Compiles this linker's fgraph and returns a function to perform the
   1172 computations, as well as lists of storage cells for both the inputs
   (...)
   1195   first_output = ostor[0].data
   1196 """
   1197 init_tasks, tasks = self.get_init_tasks()
-> 1198 cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
   1199     input_storage, output_storage, storage_map
   1200 )
   1202 res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)
   1203 res.nodes = self.node_order

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1133, in CLinker.__compile__(self, input_storage, output_storage, storage_map)
   1131 input_storage = tuple(input_storage)
   1132 output_storage = tuple(output_storage)
-> 1133 thunk, module = self.cthunk_factory(
   1134     error_storage,
   1135     input_storage,
   1136     output_storage,
   1137     storage_map,
   1138 )
   1139 return (
   1140     thunk,
   1141     module,
   (...)
   1150     error_storage,
   1151 )

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1629, in CLinker.cthunk_factory(self, error_storage, in_storage, out_storage, storage_map)
   1627     for node in self.node_order:
   1628         node.op.prepare_node(node, storage_map, None, "c")
-> 1629     module = get_module_cache().module_from_key(key=key, lnk=self)
   1631 vars = self.inputs + self.outputs + self.orphans
   1632 # List of indices that should be ignored when passing the arguments
   1633 # (basically, everything that the previous call to uniq eliminated)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:1223, in ModuleCache.module_from_key(self, key, lnk)
   1221 try:
   1222     location = dlimport_workdir(self.dirname)
-> 1223     module = lnk.compile_cmodule(location)
   1224     name = module.__file__
   1225     assert name.startswith(location)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/basic.py:1538, in CLinker.compile_cmodule(self, location)
   1536 try:
   1537     _logger.debug(f"LOCATION {location}")
-> 1538     module = c_compiler.compile_str(
   1539         module_name=mod.code_hash,
   1540         src_code=src_code,
   1541         location=location,
   1542         include_dirs=self.header_dirs(),
   1543         lib_dirs=self.lib_dirs(),
   1544         libs=libs,
   1545         preargs=preargs,
   1546     )
   1547 except Exception as e:
   1548     e.args += (str(self.fgraph),)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/cmodule.py:2636, in GCC_compiler.compile_str(module_name, src_code, location, include_dirs, lib_dirs, libs, preargs, py_module, hide_symbols)
   2628                 print(
   2629                     "Check if package python-dev or python-devel is installed."
   2630                 )
   2632     # We replace '\n' by '. ' in the error message because when Python
   2633     # prints the exception, having '\n' in the text makes it more
   2634     # difficult to read.
   2635     # compile_stderr = compile_stderr.replace("\n", ". ")
-> 2636     raise CompileError(
   2637         f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
   2638     )
   2639 elif config.cmodule__compilation_warning and compile_stderr:
   2640     # Print errors just below the command line.
   2641     print(compile_stderr)

CompileError: Compilation failed (return status=1):
/Users/conor/miniconda3/envs/pymc_non_dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -Wno-c++11-narrowing -fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/numpy/core/include -I/Users/conor/miniconda3/envs/pymc_non_dev/include/python3.10 -I/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/c/c_code -L/Users/conor/miniconda3/envs/pymc_non_dev/lib -fvisibility=hidden -o /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/m5d0ee6bbad3216d425b6f97b97d8e0d98406727e7556194d542dc09b367c4d92.so /Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: fatal error: bracket nesting level exceeded maximum of 256
        if (!PyErr_Occurred()) {
                               ^
/Users/conor/.aesara/compiledir_macOS-12.4-arm64-arm-64bit-arm-3.10.5-64/tmprd_gvep5/mod.cpp:53956:32: note: use -fbracket-depth=N to increase maximum nesting level
1 error generated.

Apply node that caused the error: Split{2148}(Elemwise{Mul}[(0, 0)].0, TensorConstant{0}, TensorConstant{(2148,) of 11})
Toposort index: 135
Inputs types: [TensorType(float64, (None, None)), TensorType(int8, ()), TensorType(int64, (2148,))]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1058, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1387, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/conor/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/gradient.py", line 1213, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Could be. I’ve never seen anyone use the sparse module of aesara in a PyMC model to be honest, so I think you’re on the “here be dragons” part of the map. It looks like it’s the C compiler that’s complaining, so that’s not a good sign. @ricardoV94 might be able to say more. If you were hell-bent on getting this approach to work, you might be able to add the -fbracket-depth=N argument to the compiler flags in your .aesararc file; that seems to be the problem. But I know literally nothing about these compiler settings.
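
For what it is worth, here is a sketch of how the flag might be passed without editing .aesararc. It assumes that Aesara takes extra C++ compiler arguments from a config option named gcc__cxxflags and that this can be set through the AESARA_FLAGS environment variable before aesara is imported. Both are assumptions to check against the Aesara config docs, and the depth of 1024 is an arbitrary illustrative value.

```python
import os

# Assumption: `gcc__cxxflags` is the Aesara config option for extra C++ compiler arguments.
# 1024 is an arbitrary illustrative depth, not a recommended value.
bracket_flag = "gcc__cxxflags=-fbracket-depth=1024"

# Append to any flags that are already set; this has to run before `import aesara`.
existing = os.environ.get("AESARA_FLAGS", "")
os.environ["AESARA_FLAGS"] = ",".join(part for part in (existing, bracket_flag) if part)

print(os.environ["AESARA_FLAGS"])
```

Even if the flag is accepted, it only raises the compiler's nesting limit; it does not shrink the generated C code, so compilation may still be slow.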

Did you try out the scan as well? That might be a bit more gentle (and better studied as a component of PyMC models).

Yes, I have tried but haven’t had any luck. I don’t understand what aesara.scan returns within a pm.Model context.
Here is how I added the scan into the model:

    result, _ = aesara.scan(lambda i, A, x, idx: A[np.array(idx==i), :].dot(x[:, i]), 
                         sequences=at.arange(num_sequence), 
                         non_sequences=[B, w, sequence_index])
    mu_stack = at.stack(result[-1]) 
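
A possible reading of the (0, 42, 6) input shape in the traceback that follows: inside the lambda, idx and i are symbolic variables, and as far as I know Aesara does not overload == for elementwise comparison (at.eq is the elementwise form), so idx==i may collapse to a single Python False. The NumPy sketch below, with sizes picked to match the traceback, shows how indexing with such a scalar boolean produces that shape instead of selecting rows. It is an illustration of the indexing rule, not a confirmed diagnosis.

```python
import numpy as np

# Sizes chosen to match the shapes in the traceback: 42 rows, 6 basis functions.
B = np.zeros((42, 6))
sequence_index = np.repeat(np.arange(2), 21)

# An elementwise mask selects the 21 rows that belong to sequence 0.
elementwise_mask = sequence_index == 0
rows_shape = B[elementwise_mask, :].shape
print(rows_shape)

# A single boolean wrapped in an array is a 0-d mask; indexing with it adds an
# axis of length 0 (for False) instead of selecting any rows.
scalar_mask = np.array(False)
collapsed_shape = B[scalar_mask, :].shape
print(collapsed_shape)
```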

Here is the error message that I got:

ValueError                                Traceback (most recent call last)
File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]

During handling of the above exception, another exception occurred:

InnerFunctionError                        Traceback (most recent call last)
File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1606, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
   1605 try:
-> 1606     t_fn, n_steps = scan_perform_ext.perform(
   1607         self.info.n_shared_outs,
   1608         self.info.n_mit_mot_outs,
   1609         self.info.n_seqs,
   1610         self.info.n_mit_mot,
   1611         self.info.n_mit_sot,
   1612         self.info.n_sit_sot,
   1613         self.info.n_nit_sot,
   1614         self.info.as_while,
   1615         cython_mintaps,
   1616         cython_pos,
   1617         cython_store_steps,
   1618         self.info.mit_mot_in_slices
   1619         + self.info.mit_sot_in_slices
   1620         + self.info.sit_sot_in_slices,
   1621         tap_array_len,
   1622         cython_vector_seqs,
   1623         cython_vector_outs,
   1624         self.info.mit_mot_out_slices,
   1625         cython_mitmots_preallocated,
   1626         mit_mot_out_to_tap_idx,
   1627         cython_outs_is_tensor,
   1628         inner_input_storage,
   1629         inner_output_storage,
   1630         cython_destroy_map,
   1631         inputs,
   1632         outputs,
   1633         outer_output_dtypes,
   1634         outer_output_ndims,
   1635         self.fn.vm,
   1636     )
   1637 except InnerFunctionError as exc:

File scan_perform.pyx:420, in aesara.scan.scan_perform.perform()

InnerFunctionError: (ValueError('Shape mismatch: A.shape[1] != x.shape[0]'), <traceback object at 0x28d615640>)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
    973 try:
    974     outputs = (
--> 975         self.vm()
    976         if output_subset is None
    977         else self.vm(output_subset=output_subset)
    978     )
    979 except Exception:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1678, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
   1675 def rval(
   1676     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
   1677 ):
-> 1678     r = p(n, [x[0] for x in i], o)
   1679     for o in node.outputs:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1645, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
   1642 if hasattr(self.fn.vm, "position_of_error") and hasattr(
   1643     self.fn.vm, "thunks"
   1644 ):
-> 1645     raise_with_op(
   1646         self.fn.maker.fgraph,
   1647         self.fn.vm.nodes[self.fn.vm.position_of_error],
   1648         self.fn.vm.thunks[self.fn.vm.position_of_error],
   1649         exc_info=(exc_type, exc_value, exc_trace),
   1650     )
   1651 else:

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]
Apply node that caused the error: CGemv{inplace}(AllocEmpty{dtype='float64'}.0, TensorConstant{1.0}, InplaceDimShuffle{x,0}.0, *2-<TensorType(float64, (None,))>, TensorConstant{0.0})
Toposort index: 4
Inputs types: [TensorType(float64, (1,)), TensorType(float64, ()), TensorType(float64, (1, None)), TensorType(float64, (None,)), TensorType(float64, ())]
Inputs shapes: [(1,), (), (1, 6), (0, 42, 6), ()]
Inputs strides: [(8,), (), (96, 16), (2016, 48, 8), ()]
Inputs values: [array([0.]), array(1.), 'not shown', array([], shape=(0, 42, 6), dtype=float64), array(0.)]
Outputs clients: [[InplaceDimShuffle{}(CGemv{inplace}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/Users/conor/Library/CloudStorage/OneDrive-QueenslandUniversityofTechnology/Documents/AustralianCancerAtlasPyMC/data-wrangling.ipynb Cell 33' in <cell line: 1>()
      1 with spline_temporal_model_ALT:
----> 2     fit_model_ALT = pm.fit()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:744, in fit(n, method, model, random_seed, start, inf_kwargs, **kwargs)
    742 else:
    743     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 744 return inference.fit(n, **kwargs)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/pymc/variational/inference.py:204, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    202 try:
    203     for i in progress:
--> 204         e = step_func()
    205         if np.isnan(e):
    206             scores = scores[:i]

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:988, in Function.__call__(self, *args, **kwargs)
    986     if hasattr(self.vm, "thunks"):
    987         thunk = self.vm.thunks[self.vm.position_of_error]
--> 988     raise_with_op(
    989         self.maker.fgraph,
    990         node=self.vm.nodes[self.vm.position_of_error],
    991         thunk=thunk,
    992         storage_map=getattr(self.vm, "storage_map", None),
    993     )
    994 else:
    995     # old-style linkers raise their own exceptions
    996     raise

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
    972 t0_fn = time.time()
    973 try:
    974     outputs = (
--> 975         self.vm()
    976         if output_subset is None
    977         else self.vm(output_subset=output_subset)
    978     )
    979 except Exception:
    980     restore_defaults()

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1678, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
   1675 def rval(
   1676     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
   1677 ):
-> 1678     r = p(n, [x[0] for x in i], o)
   1679     for o in node.outputs:
   1680         compute_map[o][0] = True

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/scan/op.py:1645, in Scan.make_thunk.<locals>.p(node, inputs, outputs)
   1640 exc_trace = exc.args[1]
   1642 if hasattr(self.fn.vm, "position_of_error") and hasattr(
   1643     self.fn.vm, "thunks"
   1644 ):
-> 1645     raise_with_op(
   1646         self.fn.maker.fgraph,
   1647         self.fn.vm.nodes[self.fn.vm.position_of_error],
   1648         self.fn.vm.thunks[self.fn.vm.position_of_error],
   1649         exc_info=(exc_type, exc_value, exc_trace),
   1650     )
   1651 else:
   1652     raise exc_value.with_traceback(exc_trace)

File ~/miniconda3/envs/pymc_non_dev/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File scan_perform.pyx:418, in aesara.scan.scan_perform.perform()

ValueError: Shape mismatch: A.shape[1] != x.shape[0]
Apply node that caused the error: CGemv{inplace}(AllocEmpty{dtype='float64'}.0, TensorConstant{1.0}, InplaceDimShuffle{x,0}.0, *2-<TensorType(float64, (None,))>, TensorConstant{0.0})
Toposort index: 4
Inputs types: [TensorType(float64, (1,)), TensorType(float64, ()), TensorType(float64, (1, None)), TensorType(float64, (None,)), TensorType(float64, ())]
Inputs shapes: [(1,), (), (1, 6), (0, 42, 6), ()]
Inputs strides: [(8,), (), (96, 16), (2016, 48, 8), ()]
Inputs values: [array([0.]), array(1.), 'not shown', array([], shape=(0, 42, 6), dtype=float64), array(0.)]
Outputs clients: [[InplaceDimShuffle{}(CGemv{inplace}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{2}, TensorConstant{[0 1]}, TensorConstant{1}, Elemwise{Add}[(0, 1)].0, AdvancedSubtensor.0)
Toposort index: 84
Inputs types: [TensorType(int64, ()), TensorType(int64, (2,)), TensorType(int64, ()), TensorType(float64, (None, None)), TensorType(float64, (None,))]
Inputs shapes: [(), (2,), (), (6, 2), (0, 42, 6)]
Inputs strides: [(), (8,), (), (16, 8), (2016, 48, 8)]
Inputs values: [array(2), array([0, 1]), array(1), 'not shown', array([], shape=(0, 42, 6), dtype=float64)]
Outputs clients: [[Subtensor{int64}(for{cpu,scan_fn}.0, ScalarConstant{0})]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

It’s saying that it can’t do the matrix multiplication A[idx == i, :] @ x[:, i] because the shapes don’t match up. Do all of your time series have the same length?
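
One clue in the traceback is the `Inputs shapes` line: the vector being multiplied has shape (0, 42, 6) rather than something like (n_i, 6). Below is a minimal NumPy sketch of one way such a shape can arise. It is only a guess, not a confirmed diagnosis: it assumes B is a (42, 6) basis matrix and that the comparison `sequence_index == i` collapsed to a single Python `False` instead of an elementwise mask. The sizes and the names `B`, `w`, and `sequence_index` are stand-ins chosen to match the shapes in the error, not the actual data.

```python
import numpy as np

# Hypothetical stand-ins matching the traceback shapes:
# 42 observations, 6 spline basis functions, 2 sequences.
B = np.ones((42, 6))
w = np.ones((6, 2))
sequence_index = np.repeat([0, 1], 21)

# Intended case: an elementwise comparison gives a boolean mask of length 42,
# so the indexed block has 6 columns and the product with w[:, 0] is valid.
mask = sequence_index == 0
rows = B[mask, :]
print(rows.shape)               # (21, 6)
print((rows @ w[:, 0]).shape)   # (21,)

# Suspect case (assumption): the comparison evaluated to a scalar False.
# A scalar boolean index adds a new leading axis of length 0.
empty = B[False, :]
print(empty.shape)              # (0, 42, 6)

# Quick check of the "same length?" question: observations per sequence.
lengths = np.bincount(sequence_index)
print(lengths)                  # [21 21]
```

If the shapes inside the scan step look like the suspect case, it may be worth checking that the comparison is built symbolically (for example with `at.eq`) rather than with Python's `==`. I have not verified that this is what happens in the model above.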

Hi @conorhassan,

I’ve been running into the same error whilst trying to use aesara.scan. Did you find the issue? (Was it just a shape mismatch as @jessegrabowski suggested?)

Thanks! David