Calculation based on columns selected from a multivariate prior fails to broadcast

I am trying to fit a model that includes a piecewise linear function. Treating the gradients and intercepts of the the piecewise linear function as independent RVs led to divergence, so instead I wish to draw them from a multivariate distribution, so that each of the parameters will be reasonable in terms of the others. The model looks like this.

with pymc.Model() as Fz_model:
    material = pymc.MutableData('material',sample[material_columns].values)
    cutting_diameter = pymc.MutableData('cutting_diameter',sample['cutting_diameter'].values)
    linear_fits = pymc.MvNormal('linear_fits',mu=numpy.array([-0.014,-0.023,-0.52,0.23]),
                               cov=numpy.array([[0.125403,-0.001801,-0.317427,0.024030],
                                               [-0.001801,0.000130,0.007583,-0.001422],
                                               [-0.317427,0.007583,0.918049,-0.091864],
                                               [0.024030,-0.001422,-0.091864,0.016044]]),
                                              shape=(n_materials,4))
    scales = pymc.Exponential('scales',lam=100.0,size=(n_materials,))
    thresholds = pymc.Deterministic('thresholds',(linear_fits[:,3]-linear_fits[:,2])/(linear_fits[:,0]-linear_fits[:1]))
    linear_fit = pymc.Deterministic('linear_fit',pymc.math.dot(material,linear_fits))
    threshold = pymc.Deterministic('threshold',pymc.math.dot(material,thresholds))
    linear0 = pymc.Deterministic('linear0', linear_fit[:,0] * cutting_diameter + linear_fit[:2])
    linear1 = pymc.Deterministic('linear1', linear_fit[:,1] * cutting_diameter + linear_fit[:,3])
    piecewise_linear = pymc.Deterministic('piecewise_linear',pymc.math.switch(pymc.math.le(cutting_diameter,
                                                                                          threshold),
                                                                             linear0,
                                                                             linear1))
    scale = pymc.Deterministic('scale',pymc.math.dot(material,scales))
    mu = pymc.Deterministic('mu', scale * cutting_diameter * pymc.math.exp(piecewise_linear))
    #sigma = pymc.Exponential('sigma',lam=0.01)
    Lambda = pymc.Deterministic('Lambda',1.0/mu)
    fz = pymc.MutableData('Fz',sample['Fz'])
    Fz_obs = pymc.Exponential('Fz_obs',lam=Lambda,observed=fz)

I have confirmed that shape of linear_fits is what I expect it to be.

Unfortunately, this fails with the following error message

ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR:aesara.graph.rewriting.basic:Rewrite failure due to: constant_folding
ERROR (aesara.graph.rewriting.basic): node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR:aesara.graph.rewriting.basic:node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR:aesara.graph.rewriting.basic:TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
    return self.fn(fgraph, node)
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
    required = thunk()
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
    thunk()
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
    raise exc_value.with_traceback(exc_trace)
AssertionError: Could not broadcast dimensions

ERROR:aesara.graph.rewriting.basic:Traceback (most recent call last):
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
    return self.fn(fgraph, node)
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
    required = thunk()
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
    thunk()
  File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
    raise exc_value.with_traceback(exc_trace)
AssertionError: Could not broadcast dimensions

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In [12], line 33
     30     fz = pymc.MutableData('Fz',sample['Fz'])
     31     Fz_obs = pymc.Exponential('Fz_obs',lam=Lambda,observed=fz)
---> 33 pymc.model_to_graphviz(Fz_model)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:433, in model_to_graphviz(model, var_names, formatting)
    427     warnings.warn(
    428         "Formattings other than 'plain' are currently not supported.",
    429         UserWarning,
    430         stacklevel=2,
    431     )
    432 model = pm.modelcontext(model)
--> 433 return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:232, in ModelGraph.make_graph(self, var_names, formatting)
    226     raise ImportError(
    227         "This function requires the python library graphviz, along with binaries. "
    228         "The easiest way to install all of this is by running\n\n"
    229         "\tconda install -c conda-forge python-graphviz"
    230     )
    231 graph = graphviz.Digraph(self.model.name)
--> 232 for plate_label, all_var_names in self.get_plates(var_names).items():
    233     if plate_label:
    234         # must be preceded by 'cluster' to get a box around it
    235         with graph.subgraph(name="cluster" + plate_label) as sub:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:211, in ModelGraph.get_plates(self, var_names)
    206         plate_label = " x ".join(
    207             f"{d} ({self._eval(self.model.dim_lengths[d])})"
    208             for d in self.model.RV_dims[var_name]
    209         )
    210     else:
--> 211         plate_label = " x ".join(map(str, self._eval(v.shape)))
    212     plates[plate_label].add(var_name)
    214 return plates

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:189, in ModelGraph._eval(self, var)
    188 def _eval(self, var):
--> 189     return function([], var, mode="FAST_COMPILE")()

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    311     fn = orig_function(
    312         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    313     )
    314 else:
    315     # note: pfunc will also call orig_function -- orig_function is
    316     #      a choke point that all compilation must pass through
--> 317     fn = pfunc(
    318         params=inputs,
    319         outputs=outputs,
    320         mode=mode,
    321         updates=updates,
    322         givens=givens,
    323         no_default_updates=no_default_updates,
    324         accept_inplace=accept_inplace,
    325         name=name,
    326         rebuild_strict=rebuild_strict,
    327         allow_input_downcast=allow_input_downcast,
    328         on_unused_input=on_unused_input,
    329         profile=profile,
    330         output_keys=output_keys,
    331     )
    332 return fn

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:371, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    357     profile = ProfileStats(message=profile)
    359 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    360     params,
    361     outputs,
   (...)
    368     fgraph=fgraph,
    369 )
--> 371 return orig_function(
    372     inputs,
    373     cloned_outputs,
    374     mode,
    375     accept_inplace=accept_inplace,
    376     name=name,
    377     profile=profile,
    378     on_unused_input=on_unused_input,
    379     output_keys=output_keys,
    380     fgraph=fgraph,
    381 )

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1747, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1745 try:
   1746     Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1747     m = Maker(
   1748         inputs,
   1749         outputs,
   1750         mode,
   1751         accept_inplace=accept_inplace,
   1752         profile=profile,
   1753         on_unused_input=on_unused_input,
   1754         output_keys=output_keys,
   1755         name=name,
   1756         fgraph=fgraph,
   1757     )
   1758     with config.change_flags(compute_test_value="off"):
   1759         fn = m.create(defaults)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1517, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
   1514 rewriter, linker = mode.optimizer, copy.copy(mode.linker)
   1516 if not no_fgraph_prep:
-> 1517     self.prepare_fgraph(
   1518         inputs, outputs, found_updates, fgraph, rewriter, linker, profile
   1519     )
   1521 assert len(fgraph.outputs) == len(outputs + found_updates)
   1523 # The 'no_borrow' outputs are the ones for which that we can't
   1524 # return the internal storage pointer.

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1407, in FunctionMaker.prepare_fgraph(inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile)
   1401 rewrite_time = None
   1403 with config.change_flags(
   1404     compute_test_value=config.compute_test_value_opt,
   1405     traceback__limit=config.traceback__compile_limit,
   1406 ):
-> 1407     rewriter_profile = rewriter(fgraph)
   1409     end_rewriter = time.time()
   1410     rewrite_time = end_rewriter - start_rewriter

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:135, in GraphRewriter.__call__(self, fgraph)
    133 def __call__(self, fgraph):
    134     """Rewrite a `FunctionGraph`."""
--> 135     return self.rewrite(fgraph)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:131, in GraphRewriter.rewrite(self, fgraph, *args, **kwargs)
    122 """
    123 
    124 This is meant as a shortcut for the following::
   (...)
    128 
    129 """
    130 self.add_requirements(fgraph)
--> 131 return self.apply(fgraph, *args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:302, in SequentialGraphRewriter.apply(self, fgraph)
    300 nb_nodes_before = len(fgraph.apply_nodes)
    301 t0 = time.time()
--> 302 sub_prof = rewriter.apply(fgraph)
    303 l.append(float(time.time() - t0))
    304 sub_profs.append(sub_prof)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2477, in EquilibriumGraphRewriter.apply(self, fgraph, start_from)
   2475 nb = change_tracker.nb_imported
   2476 t_rewrite = time.time()
-> 2477 sub_prof = grewrite.apply(fgraph)
   2478 time_rewriters[grewrite] += time.time() - t_rewrite
   2479 sub_profs.append(sub_prof)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2051, in WalkingGraphRewriter.apply(self, fgraph, start_from)
   2049             continue
   2050         current_node = node
-> 2051         nb += self.process_node(fgraph, node)
   2052     loop_t = time.time() - t0
   2053 finally:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1936, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
   1934 except Exception as e:
   1935     if self.failure_callback is not None:
-> 1936         self.failure_callback(
   1937             e, self, [(x, None) for x in node.outputs], node_rewriter, node
   1938         )
   1939         return False
   1940     else:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1789, in NodeProcessingGraphRewriter.warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node)
   1787 if isinstance(exc, InconsistencyError):
   1788     return
-> 1789 return cls.warn(exc, nav, repl_pairs, node_rewriter, node)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1777, in NodeProcessingGraphRewriter.warn(cls, exc, nav, repl_pairs, node_rewriter, node)
   1773     pdb.post_mortem(sys.exc_info()[2])
   1774 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
   1775     # We always crash on AssertionError because something may be
   1776     # seriously wrong if such an exception is raised.
-> 1777     raise exc

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1933, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
   1931 assert node_rewriter is not None
   1932 try:
-> 1933     replacements = node_rewriter.transform(fgraph, node)
   1934 except Exception as e:
   1935     if self.failure_callback is not None:

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1092, in FromFunctionNodeRewriter.transform(self, fgraph, node)
   1087     if not (
   1088         node.op in self._tracks or isinstance(node.op, self._tracked_types)
   1089     ):
   1090         return False
-> 1092 return self.fn(fgraph, node)

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py:1142, in constant_folding(fgraph, node)
   1139     compute_map[o] = [False]
   1141 thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-> 1142 required = thunk()
   1144 # A node whose inputs are all provided should always return successfully
   1145 assert not required

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py:103, in COp.make_c_thunk.<locals>.rval()
    101 @is_cthunk_wrapper_type
    102 def rval():
--> 103     thunk()
    104     for o in node.outputs:
    105         compute_map[o][0] = True

File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py:1788, in _CThunk.__call__(self)
   1786     print(self.error_storage, file=sys.stderr)
   1787     raise
-> 1788 raise exc_value.with_traceback(exc_trace)

AssertionError: Could not broadcast dimensions


It appears that PyMC or Aesara objects to me performing operations of particular columns of linear_fits - for example, Fz_model.thresholds.shape.eval() fails.

What do I need to do to make PyMC handle these operations correctly?

Aren’t you missing a comma here?

That explains it.