I am trying to fit a model that includes a piecewise linear function. Treating the gradients and intercepts of the piecewise linear function as independent RVs led to divergences, so instead I wish to draw them jointly from a multivariate distribution, so that each parameter will be reasonable in terms of the others. The model looks like this:
with pymc.Model() as Fz_model:
    # Observed covariates.
    # NOTE(review): `material` is assumed to be a one-hot (n_obs, n_materials)
    # indicator matrix, so that dot(material, per_material_param) picks each
    # observation's per-material value -- TODO confirm.
    material = pymc.MutableData('material', sample[material_columns].values)
    cutting_diameter = pymc.MutableData('cutting_diameter', sample['cutting_diameter'].values)

    # Per-material piecewise-linear parameters, drawn jointly from one MvNormal
    # so that slopes and intercepts stay mutually consistent.  One row per
    # material; column layout (as consumed below):
    #   0: slope of segment 0      1: slope of segment 1
    #   2: intercept of segment 0  3: intercept of segment 1
    linear_fits = pymc.MvNormal(
        'linear_fits',
        mu=numpy.array([-0.014, -0.023, -0.52, 0.23]),
        cov=numpy.array([
            [0.125403, -0.001801, -0.317427, 0.024030],
            [-0.001801, 0.000130, 0.007583, -0.001422],
            [-0.317427, 0.007583, 0.918049, -0.091864],
            [0.024030, -0.001422, -0.091864, 0.016044],
        ]),
        shape=(n_materials, 4),
    )
    scales = pymc.Exponential('scales', lam=100.0, size=(n_materials,))

    # Pull each column out ONCE with `[:, k]`, which yields a length-n_materials
    # vector.  Writing `[:k]` instead slices *rows* (e.g. `[:1]` is a (1, 4)
    # matrix), which cannot broadcast against an (n_materials,) vector and makes
    # Aesara raise "Could not broadcast dimensions".
    slope0 = linear_fits[:, 0]
    slope1 = linear_fits[:, 1]
    intercept0 = linear_fits[:, 2]
    intercept1 = linear_fits[:, 3]

    # Diameter at which the two segments meet:
    #   slope0 * x + intercept0 == slope1 * x + intercept1
    #   =>  x = (intercept1 - intercept0) / (slope0 - slope1)
    thresholds = pymc.Deterministic('thresholds', (intercept1 - intercept0) / (slope0 - slope1))

    # Map per-material quantities onto observations.
    linear_fit = pymc.Deterministic('linear_fit', pymc.math.dot(material, linear_fits))
    threshold = pymc.Deterministic('threshold', pymc.math.dot(material, thresholds))

    # Both segments evaluated per observation; columns follow the layout above.
    linear0 = pymc.Deterministic('linear0', linear_fit[:, 0] * cutting_diameter + linear_fit[:, 2])
    linear1 = pymc.Deterministic('linear1', linear_fit[:, 1] * cutting_diameter + linear_fit[:, 3])

    # Segment 0 up to and including the threshold, segment 1 beyond it.
    piecewise_linear = pymc.Deterministic(
        'piecewise_linear',
        pymc.math.switch(pymc.math.le(cutting_diameter, threshold), linear0, linear1),
    )

    scale = pymc.Deterministic('scale', pymc.math.dot(material, scales))
    mu = pymc.Deterministic('mu', scale * cutting_diameter * pymc.math.exp(piecewise_linear))

    # Exponential likelihood parameterised by rate = 1 / mean.
    Lambda = pymc.Deterministic('Lambda', 1.0 / mu)
    fz = pymc.MutableData('Fz', sample['Fz'].values)
    Fz_obs = pymc.Exponential('Fz_obs', lam=Lambda, observed=fz)
I have confirmed that the shape of `linear_fits`
is what I expect it to be.
Unfortunately, this fails with the following error message:
ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR:aesara.graph.rewriting.basic:Rewrite failure due to: constant_folding
ERROR (aesara.graph.rewriting.basic): node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR:aesara.graph.rewriting.basic:node: Assert{msg=Could not broadcast dimensions}(ScalarConstant{16}, ScalarConstant{False})
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR:aesara.graph.rewriting.basic:TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
replacements = node_rewriter.transform(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
return self.fn(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
required = thunk()
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
thunk()
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
raise exc_value.with_traceback(exc_trace)
AssertionError: Could not broadcast dimensions
ERROR:aesara.graph.rewriting.basic:Traceback (most recent call last):
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
replacements = node_rewriter.transform(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
return self.fn(fgraph, node)
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
required = thunk()
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
thunk()
File "/home/peter/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
raise exc_value.with_traceback(exc_trace)
AssertionError: Could not broadcast dimensions
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In [12], line 33
30 fz = pymc.MutableData('Fz',sample['Fz'])
31 Fz_obs = pymc.Exponential('Fz_obs',lam=Lambda,observed=fz)
---> 33 pymc.model_to_graphviz(Fz_model)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:433, in model_to_graphviz(model, var_names, formatting)
427 warnings.warn(
428 "Formattings other than 'plain' are currently not supported.",
429 UserWarning,
430 stacklevel=2,
431 )
432 model = pm.modelcontext(model)
--> 433 return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:232, in ModelGraph.make_graph(self, var_names, formatting)
226 raise ImportError(
227 "This function requires the python library graphviz, along with binaries. "
228 "The easiest way to install all of this is by running\n\n"
229 "\tconda install -c conda-forge python-graphviz"
230 )
231 graph = graphviz.Digraph(self.model.name)
--> 232 for plate_label, all_var_names in self.get_plates(var_names).items():
233 if plate_label:
234 # must be preceded by 'cluster' to get a box around it
235 with graph.subgraph(name="cluster" + plate_label) as sub:
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:211, in ModelGraph.get_plates(self, var_names)
206 plate_label = " x ".join(
207 f"{d} ({self._eval(self.model.dim_lengths[d])})"
208 for d in self.model.RV_dims[var_name]
209 )
210 else:
--> 211 plate_label = " x ".join(map(str, self._eval(v.shape)))
212 plates[plate_label].add(var_name)
214 return plates
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/pymc/model_graph.py:189, in ModelGraph._eval(self, var)
188 def _eval(self, var):
--> 189 return function([], var, mode="FAST_COMPILE")()
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
311 fn = orig_function(
312 inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
313 )
314 else:
315 # note: pfunc will also call orig_function -- orig_function is
316 # a choke point that all compilation must pass through
--> 317 fn = pfunc(
318 params=inputs,
319 outputs=outputs,
320 mode=mode,
321 updates=updates,
322 givens=givens,
323 no_default_updates=no_default_updates,
324 accept_inplace=accept_inplace,
325 name=name,
326 rebuild_strict=rebuild_strict,
327 allow_input_downcast=allow_input_downcast,
328 on_unused_input=on_unused_input,
329 profile=profile,
330 output_keys=output_keys,
331 )
332 return fn
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:371, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
357 profile = ProfileStats(message=profile)
359 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
360 params,
361 outputs,
(...)
368 fgraph=fgraph,
369 )
--> 371 return orig_function(
372 inputs,
373 cloned_outputs,
374 mode,
375 accept_inplace=accept_inplace,
376 name=name,
377 profile=profile,
378 on_unused_input=on_unused_input,
379 output_keys=output_keys,
380 fgraph=fgraph,
381 )
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1747, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1745 try:
1746 Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1747 m = Maker(
1748 inputs,
1749 outputs,
1750 mode,
1751 accept_inplace=accept_inplace,
1752 profile=profile,
1753 on_unused_input=on_unused_input,
1754 output_keys=output_keys,
1755 name=name,
1756 fgraph=fgraph,
1757 )
1758 with config.change_flags(compute_test_value="off"):
1759 fn = m.create(defaults)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1517, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1514 rewriter, linker = mode.optimizer, copy.copy(mode.linker)
1516 if not no_fgraph_prep:
-> 1517 self.prepare_fgraph(
1518 inputs, outputs, found_updates, fgraph, rewriter, linker, profile
1519 )
1521 assert len(fgraph.outputs) == len(outputs + found_updates)
1523 # The 'no_borrow' outputs are the ones for which that we can't
1524 # return the internal storage pointer.
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/compile/function/types.py:1407, in FunctionMaker.prepare_fgraph(inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile)
1401 rewrite_time = None
1403 with config.change_flags(
1404 compute_test_value=config.compute_test_value_opt,
1405 traceback__limit=config.traceback__compile_limit,
1406 ):
-> 1407 rewriter_profile = rewriter(fgraph)
1409 end_rewriter = time.time()
1410 rewrite_time = end_rewriter - start_rewriter
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:135, in GraphRewriter.__call__(self, fgraph)
133 def __call__(self, fgraph):
134 """Rewrite a `FunctionGraph`."""
--> 135 return self.rewrite(fgraph)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:131, in GraphRewriter.rewrite(self, fgraph, *args, **kwargs)
122 """
123
124 This is meant as a shortcut for the following::
(...)
128
129 """
130 self.add_requirements(fgraph)
--> 131 return self.apply(fgraph, *args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:302, in SequentialGraphRewriter.apply(self, fgraph)
300 nb_nodes_before = len(fgraph.apply_nodes)
301 t0 = time.time()
--> 302 sub_prof = rewriter.apply(fgraph)
303 l.append(float(time.time() - t0))
304 sub_profs.append(sub_prof)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2477, in EquilibriumGraphRewriter.apply(self, fgraph, start_from)
2475 nb = change_tracker.nb_imported
2476 t_rewrite = time.time()
-> 2477 sub_prof = grewrite.apply(fgraph)
2478 time_rewriters[grewrite] += time.time() - t_rewrite
2479 sub_profs.append(sub_prof)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:2051, in WalkingGraphRewriter.apply(self, fgraph, start_from)
2049 continue
2050 current_node = node
-> 2051 nb += self.process_node(fgraph, node)
2052 loop_t = time.time() - t0
2053 finally:
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1936, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
1934 except Exception as e:
1935 if self.failure_callback is not None:
-> 1936 self.failure_callback(
1937 e, self, [(x, None) for x in node.outputs], node_rewriter, node
1938 )
1939 return False
1940 else:
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1789, in NodeProcessingGraphRewriter.warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node)
1787 if isinstance(exc, InconsistencyError):
1788 return
-> 1789 return cls.warn(exc, nav, repl_pairs, node_rewriter, node)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1777, in NodeProcessingGraphRewriter.warn(cls, exc, nav, repl_pairs, node_rewriter, node)
1773 pdb.post_mortem(sys.exc_info()[2])
1774 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1775 # We always crash on AssertionError because something may be
1776 # seriously wrong if such an exception is raised.
-> 1777 raise exc
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1933, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
1931 assert node_rewriter is not None
1932 try:
-> 1933 replacements = node_rewriter.transform(fgraph, node)
1934 except Exception as e:
1935 if self.failure_callback is not None:
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py:1092, in FromFunctionNodeRewriter.transform(self, fgraph, node)
1087 if not (
1088 node.op in self._tracks or isinstance(node.op, self._tracked_types)
1089 ):
1090 return False
-> 1092 return self.fn(fgraph, node)
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py:1142, in constant_folding(fgraph, node)
1139 compute_map[o] = [False]
1141 thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-> 1142 required = thunk()
1144 # A node whose inputs are all provided should always return successfully
1145 assert not required
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/op.py:103, in COp.make_c_thunk.<locals>.rval()
101 @is_cthunk_wrapper_type
102 def rval():
--> 103 thunk()
104 for o in node.outputs:
105 compute_map[o][0] = True
File ~/.cache/pypoetry/virtualenvs/tool-insight-HtaDeJzl-py3.10/lib/python3.10/site-packages/aesara/link/c/basic.py:1788, in _CThunk.__call__(self)
1786 print(self.error_storage, file=sys.stderr)
1787 raise
-> 1788 raise exc_value.with_traceback(exc_trace)
AssertionError: Could not broadcast dimensions
It appears that PyMC or Aesara objects to my performing operations on particular columns of `linear_fits`
- for example, `Fz_model.thresholds.shape.eval()`
fails.
What do I need to do to make PyMC handle these operations correctly?