# Calculation based on columns selected from a multivariate prior fails to broadcast

I am trying to fit a model that includes a piecewise linear function. Treating the gradients and intercepts of the piecewise linear function as independent RVs led to divergences, so instead I wish to draw them from a multivariate distribution, so that each parameter will be reasonable in terms of the others. The model looks like this:

``````with pymc.Model() as Fz_model:
material = pymc.MutableData('material',sample[material_columns].values)
cutting_diameter = pymc.MutableData('cutting_diameter',sample['cutting_diameter'].values)
linear_fits = pymc.MvNormal('linear_fits',mu=numpy.array([-0.014,-0.023,-0.52,0.23]),
cov=numpy.array([[0.125403,-0.001801,-0.317427,0.024030],
[-0.001801,0.000130,0.007583,-0.001422],
[-0.317427,0.007583,0.918049,-0.091864],
[0.024030,-0.001422,-0.091864,0.016044]]),
shape=(n_materials,4))
scales = pymc.Exponential('scales',lam=100.0,size=(n_materials,))
thresholds = pymc.Deterministic('thresholds',(linear_fits[:,3]-linear_fits[:,2])/(linear_fits[:,0]-linear_fits[:1]))
linear_fit = pymc.Deterministic('linear_fit',pymc.math.dot(material,linear_fits))
threshold = pymc.Deterministic('threshold',pymc.math.dot(material,thresholds))
linear0 = pymc.Deterministic('linear0', linear_fit[:,0] * cutting_diameter + linear_fit[:2])
linear1 = pymc.Deterministic('linear1', linear_fit[:,1] * cutting_diameter + linear_fit[:,3])
piecewise_linear = pymc.Deterministic('piecewise_linear',pymc.math.switch(pymc.math.le(cutting_diameter,
threshold),
linear0,
linear1))
scale = pymc.Deterministic('scale',pymc.math.dot(material,scales))
mu = pymc.Deterministic('mu', scale * cutting_diameter * pymc.math.exp(piecewise_linear))
#sigma = pymc.Exponential('sigma',lam=0.01)
Lambda = pymc.Deterministic('Lambda',1.0/mu)
fz = pymc.MutableData('Fz',sample['Fz'])
Fz_obs = pymc.Exponential('Fz_obs',lam=Lambda,observed=fz)
``````

I have confirmed that the shape of `linear_fits` is what I expect it to be, `(n_materials, 4)`.

Unfortunately, this fails with the following error message when I call `pymc.model_to_graphviz(Fz_model)`:

``````
ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR:aesara.graph.rewriting.basic:Rewrite failure due to: constant_folding
ERROR (aesara.graph.rewriting.basic): node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR:aesara.graph.rewriting.basic:node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR:aesara.graph.rewriting.basic:TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
replacements = node_rewriter.transform(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
return self.fn(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
required = thunk()
thunk()
raise exc_value.with_traceback(exc_trace)

ERROR:aesara.graph.rewriting.basic:Traceback (most recent call last):
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
replacements = node_rewriter.transform(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
return self.fn(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
required = thunk()
thunk()
raise exc_value.with_traceback(exc_trace)

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In [12], line 33
30     fz = pymc.MutableData('Fz',sample['Fz'])
31     Fz_obs = pymc.Exponential('Fz_obs',lam=Lambda,observed=fz)
---> 33 pymc.model_to_graphviz(Fz_model)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:433, in model_to_graphviz(model, var_names, formatting)
427     warnings.warn(
428         "Formattings other than 'plain' are currently not supported.",
429         UserWarning,
430         stacklevel=2,
431     )
432 model = pm.modelcontext(model)
--> 433 return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:232, in ModelGraph.make_graph(self, var_names, formatting)
226     raise ImportError(
227         "This function requires the python library graphviz, along with binaries. "
228         "The easiest way to install all of this is by running\n\n"
229         "\tconda install -c conda-forge python-graphviz"
230     )
231 graph = graphviz.Digraph(self.model.name)
--> 232 for plate_label, all_var_names in self.get_plates(var_names).items():
233     if plate_label:
234         # must be preceded by 'cluster' to get a box around it
235         with graph.subgraph(name="cluster" + plate_label) as sub:

206         plate_label = " x ".join(
207             f"{d} ({self._eval(self.model.dim_lengths[d])})"
208             for d in self.model.RV_dims[var_name]
209         )
210     else:
--> 211         plate_label = " x ".join(map(str, self._eval(v.shape)))
214 return plates

188 def _eval(self, var):
--> 189     return function([], var, mode="FAST_COMPILE")()

311     fn = orig_function(
312         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
313     )
314 else:
315     # note: pfunc will also call orig_function -- orig_function is
316     #      a choke point that all compilation must pass through
--> 317     fn = pfunc(
318         params=inputs,
319         outputs=outputs,
320         mode=mode,
322         givens=givens,
324         accept_inplace=accept_inplace,
325         name=name,
326         rebuild_strict=rebuild_strict,
327         allow_input_downcast=allow_input_downcast,
328         on_unused_input=on_unused_input,
329         profile=profile,
330         output_keys=output_keys,
331     )
332 return fn

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:371, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
357     profile = ProfileStats(message=profile)
359 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
360     params,
361     outputs,
(...)
368     fgraph=fgraph,
369 )
--> 371 return orig_function(
372     inputs,
373     cloned_outputs,
374     mode,
375     accept_inplace=accept_inplace,
376     name=name,
377     profile=profile,
378     on_unused_input=on_unused_input,
379     output_keys=output_keys,
380     fgraph=fgraph,
381 )

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1747, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1745 try:
1746     Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1747     m = Maker(
1748         inputs,
1749         outputs,
1750         mode,
1751         accept_inplace=accept_inplace,
1752         profile=profile,
1753         on_unused_input=on_unused_input,
1754         output_keys=output_keys,
1755         name=name,
1756         fgraph=fgraph,
1757     )
1758     with config.change_flags(compute_test_value="off"):
1759         fn = m.create(defaults)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1517, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1516 if not no_fgraph_prep:
-> 1517     self.prepare_fgraph(
1519     )
1521 assert len(fgraph.outputs) == len(outputs + found_updates)
1523 # The 'no_borrow' outputs are the ones for which that we can't
1524 # return the internal storage pointer.

1401 rewrite_time = None
1403 with config.change_flags(
1404     compute_test_value=config.compute_test_value_opt,
1405     traceback__limit=config.traceback__compile_limit,
1406 ):
-> 1407     rewriter_profile = rewriter(fgraph)
1409     end_rewriter = time.time()
1410     rewrite_time = end_rewriter - start_rewriter

133 def __call__(self, fgraph):
134     """Rewrite a `FunctionGraph`."""
--> 135     return self.rewrite(fgraph)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:131, in GraphRewriter.rewrite(self, fgraph, *args, **kwargs)
122 """
123
124 This is meant as a shortcut for the following::
(...)
128
129 """
--> 131 return self.apply(fgraph, *args, **kwargs)

300 nb_nodes_before = len(fgraph.apply_nodes)
301 t0 = time.time()
--> 302 sub_prof = rewriter.apply(fgraph)
303 l.append(float(time.time() - t0))
304 sub_profs.append(sub_prof)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2477, in EquilibriumGraphRewriter.apply(self, fgraph, start_from)
2475 nb = change_tracker.nb_imported
2476 t_rewrite = time.time()
-> 2477 sub_prof = grewrite.apply(fgraph)
2478 time_rewriters[grewrite] += time.time() - t_rewrite
2479 sub_profs.append(sub_prof)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2051, in WalkingGraphRewriter.apply(self, fgraph, start_from)
2049             continue
2050         current_node = node
-> 2051         nb += self.process_node(fgraph, node)
2052     loop_t = time.time() - t0
2053 finally:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1936, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
1934 except Exception as e:
1935     if self.failure_callback is not None:
-> 1936         self.failure_callback(
1937             e, self, [(x, None) for x in node.outputs], node_rewriter, node
1938         )
1939         return False
1940     else:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1789, in NodeProcessingGraphRewriter.warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node)
1787 if isinstance(exc, InconsistencyError):
1788     return
-> 1789 return cls.warn(exc, nav, repl_pairs, node_rewriter, node)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1777, in NodeProcessingGraphRewriter.warn(cls, exc, nav, repl_pairs, node_rewriter, node)
1773     pdb.post_mortem(sys.exc_info()[2])
1774 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1775     # We always crash on AssertionError because something may be
1776     # seriously wrong if such an exception is raised.
-> 1777     raise exc

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1933, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
1931 assert node_rewriter is not None
1932 try:
-> 1933     replacements = node_rewriter.transform(fgraph, node)
1934 except Exception as e:
1935     if self.failure_callback is not None:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1092, in FromFunctionNodeRewriter.transform(self, fgraph, node)
1087     if not (
1088         node.op in self._tracks or isinstance(node.op, self._tracked_types)
1089     ):
1090         return False
-> 1092 return self.fn(fgraph, node)

1139     compute_map[o] = [False]
1141 thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-> 1142 required = thunk()
1144 # A node whose inputs are all provided should always return successfully
1145 assert not required

101 @is_cthunk_wrapper_type
102 def rval():
--> 103     thunk()
104     for o in node.outputs:
105         compute_map[o][0] = True

1786     print(self.error_storage, file=sys.stderr)
1787     raise
-> 1788 raise exc_value.with_traceback(exc_trace)

``````

It appears that PyMC or Aesara objects to my performing operations on particular columns of `linear_fits`; for example, `Fz_model.thresholds.shape.eval()` fails.