Thank you both for your helpful responses!
Junpeng, the first snippet seems to run fine for me, thanks! The second one looks cool, but unfortunately I get a long error message:
ERROR (aesara.graph.opt): Optimization failure due to: local_IncSubtensor_serialize
ERROR (aesara.graph.opt): node: Elemwise{add,no_inplace}(Elemwise{add,no_inplace}.0, AdvancedIncSubtensor{inplace=False, set_instead_of_inc=False}.0)
ERROR (aesara.graph.opt): TRACEBACK:
ERROR (aesara.graph.opt): Traceback (most recent call last):
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py", line 1861, in process_node
replacements = lopt.transform(fgraph, node)
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py", line 1066, in transform
return self.fn(fgraph, node)
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py", line 1203, in local_IncSubtensor_serialize
assert mi.owner.inputs[0].type.is_super(tip.type)
AssertionError
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/vm.py:1245, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245 node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1534, in Scan.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
1531 # Analyse the compile inner function to determine which inputs and
1532 # outputs are on the gpu and speed up some checks during the execution
1533 outs_is_tensor = [
-> 1534 isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
1535 ]
1537 try:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1466, in Scan.fn(self)
1464 profile = self.profile
-> 1466 self._fn = pfunc(
1467 wrapped_inputs,
1468 wrapped_outputs,
1469 mode=self.mode_instance,
1470 accept_inplace=False,
1471 profile=profile,
1472 on_unused_input="ignore",
1473 fgraph=self.fgraph,
1474 )
1476 return self._fn
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363 params,
364 outputs,
(...)
371 fgraph=fgraph,
372 )
--> 374 return orig_function(
375 inputs,
376 cloned_outputs,
377 mode,
378 accept_inplace=accept_inplace,
379 name=name,
380 profile=profile,
381 on_unused_input=on_unused_input,
382 output_keys=output_keys,
383 fgraph=fgraph,
384 )
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1751, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1750 Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751 m = Maker(
1752 inputs,
1753 outputs,
1754 mode,
1755 accept_inplace=accept_inplace,
1756 profile=profile,
1757 on_unused_input=on_unused_input,
1758 output_keys=output_keys,
1759 name=name,
1760 fgraph=fgraph,
1761 )
1762 with config.change_flags(compute_test_value="off"):
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1521, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1520 if not no_fgraph_prep:
-> 1521 self.prepare_fgraph(
1522 inputs, outputs, found_updates, fgraph, optimizer, linker, profile
1523 )
1525 assert len(fgraph.outputs) == len(outputs + found_updates)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1411, in FunctionMaker.prepare_fgraph(inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile)
1407 with config.change_flags(
1408 compute_test_value=config.compute_test_value_opt,
1409 traceback__limit=config.traceback__compile_limit,
1410 ):
-> 1411 optimizer_profile = optimizer(fgraph)
1413 end_optimizer = time.time()
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:111, in GlobalOptimizer.__call__(self, fgraph)
106 """Optimize a `FunctionGraph`.
107
108 This is the same as ``self.optimize(fgraph)``.
109
110 """
--> 111 return self.optimize(fgraph)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:102, in GlobalOptimizer.optimize(self, fgraph, *args, **kwargs)
101 self.add_requirements(fgraph)
--> 102 ret = self.apply(fgraph, *args, **kwargs)
103 return ret
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:279, in SeqOptimizer.apply(self, fgraph)
278 t0 = time.time()
--> 279 sub_prof = optimizer.apply(fgraph)
280 l.append(float(time.time() - t0))
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1971, in TopoOptimizer.apply(self, fgraph, start_from)
1970 current_node = node
-> 1971 nb += self.process_node(fgraph, node)
1972 loop_t = time.time() - t0
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1864, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1863 if self.failure_callback is not None:
-> 1864 self.failure_callback(
1865 e, self, [(x, None) for x in node.outputs], lopt, node
1866 )
1867 return False
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1767, in NavigatorOptimizer.warn_inplace(exc, nav, repl_pairs, local_opt, node)
1766 return
-> 1767 return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1755, in NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
1752 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1753 # We always crash on AssertionError because something may be
1754 # seriously wrong if such an exception is raised.
-> 1755 raise exc
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1861, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1860 try:
-> 1861 replacements = lopt.transform(fgraph, node)
1862 except Exception as e:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1066, in FromFunctionLocalOptimizer.transform(self, fgraph, node)
1064 return False
-> 1066 return self.fn(fgraph, node)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py:1203, in local_IncSubtensor_serialize(fgraph, node)
1202 assert o_type.is_super(tip.type)
-> 1203 assert mi.owner.inputs[0].type.is_super(tip.type)
1204 tip = mi.owner.op(tip, *mi.owner.inputs[1:])
AssertionError:
During handling of the above exception, another exception occurred:
AssertionError Traceback (most recent call last)
Input In [5], in <cell line: 20>()
17 vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
18 hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))
---> 20 hvp_fn = aesara.function([theta, b], [hvp_clone])
21 hvp_fn(q.data, q.data)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
311 fn = orig_function(
312 inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
313 )
314 else:
315 # note: pfunc will also call orig_function -- orig_function is
316 # a choke point that all compilation must pass through
--> 317 fn = pfunc(
318 params=inputs,
319 outputs=outputs,
320 mode=mode,
321 updates=updates,
322 givens=givens,
323 no_default_updates=no_default_updates,
324 accept_inplace=accept_inplace,
325 name=name,
326 rebuild_strict=rebuild_strict,
327 allow_input_downcast=allow_input_downcast,
328 on_unused_input=on_unused_input,
329 profile=profile,
330 output_keys=output_keys,
331 )
332 return fn
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
360 profile = ProfileStats(message=profile)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363 params,
364 outputs,
(...)
371 fgraph=fgraph,
372 )
--> 374 return orig_function(
375 inputs,
376 cloned_outputs,
377 mode,
378 accept_inplace=accept_inplace,
379 name=name,
380 profile=profile,
381 on_unused_input=on_unused_input,
382 output_keys=output_keys,
383 fgraph=fgraph,
384 )
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1763, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1751 m = Maker(
1752 inputs,
1753 outputs,
(...)
1760 fgraph=fgraph,
1761 )
1762 with config.change_flags(compute_test_value="off"):
-> 1763 fn = m.create(defaults)
1764 finally:
1765 t2 = time.time()
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1656, in FunctionMaker.create(self, input_storage, trustme, storage_map)
1653 start_import_time = aesara.link.c.cmodule.import_time
1655 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1656 _fn, _i, _o = self.linker.make_thunk(
1657 input_storage=input_storage_lists, storage_map=storage_map
1658 )
1660 end_linker = time.time()
1662 linker_time = end_linker - start_linker
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
247 def make_thunk(
248 self,
249 input_storage: Optional["InputStorageType"] = None,
(...)
252 **kwargs,
253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254 return self.make_all(
255 input_storage=input_storage,
256 output_storage=output_storage,
257 storage_map=storage_map,
258 )[:3]
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/vm.py:1254, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
1252 thunks[-1].lazy = False
1253 except Exception:
-> 1254 raise_with_op(fgraph, node)
1256 t1 = time.time()
1258 if self.profile:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529 warnings.warn(
530 f"{exc_type} error does not allow us to add an extra error message"
531 )
532 # Some exception need extra parameter in inputs. So forget the
533 # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/vm.py:1245, in VMLinker.make_all(self, profiler, input_storage, output_storage, storage_map)
1240 thunk_start = time.time()
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245 node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start
1248 if not hasattr(thunks[-1], "lazy"):
1249 # We don't want all ops maker to think about lazy Ops.
1250 # So if they didn't specify that its lazy or not, it isn't.
1251 # If this member isn't present, it will crash later.
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1534, in Scan.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
1529 node_output_storage = [storage_map[r] for r in node.outputs]
1531 # Analyse the compile inner function to determine which inputs and
1532 # outputs are on the gpu and speed up some checks during the execution
1533 outs_is_tensor = [
-> 1534 isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
1535 ]
1537 try:
1538 if impl == "py":
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1466, in Scan.fn(self)
1463 elif self.profile:
1464 profile = self.profile
-> 1466 self._fn = pfunc(
1467 wrapped_inputs,
1468 wrapped_outputs,
1469 mode=self.mode_instance,
1470 accept_inplace=False,
1471 profile=profile,
1472 on_unused_input="ignore",
1473 fgraph=self.fgraph,
1474 )
1476 return self._fn
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
360 profile = ProfileStats(message=profile)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363 params,
364 outputs,
(...)
371 fgraph=fgraph,
372 )
--> 374 return orig_function(
375 inputs,
376 cloned_outputs,
377 mode,
378 accept_inplace=accept_inplace,
379 name=name,
380 profile=profile,
381 on_unused_input=on_unused_input,
382 output_keys=output_keys,
383 fgraph=fgraph,
384 )
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1751, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1749 try:
1750 Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751 m = Maker(
1752 inputs,
1753 outputs,
1754 mode,
1755 accept_inplace=accept_inplace,
1756 profile=profile,
1757 on_unused_input=on_unused_input,
1758 output_keys=output_keys,
1759 name=name,
1760 fgraph=fgraph,
1761 )
1762 with config.change_flags(compute_test_value="off"):
1763 fn = m.create(defaults)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1521, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1518 optimizer, linker = mode.optimizer, copy.copy(mode.linker)
1520 if not no_fgraph_prep:
-> 1521 self.prepare_fgraph(
1522 inputs, outputs, found_updates, fgraph, optimizer, linker, profile
1523 )
1525 assert len(fgraph.outputs) == len(outputs + found_updates)
1527 # The 'no_borrow' outputs are the ones for which that we can't
1528 # return the internal storage pointer.
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1411, in FunctionMaker.prepare_fgraph(inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile)
1405 opt_time = None
1407 with config.change_flags(
1408 compute_test_value=config.compute_test_value_opt,
1409 traceback__limit=config.traceback__compile_limit,
1410 ):
-> 1411 optimizer_profile = optimizer(fgraph)
1413 end_optimizer = time.time()
1414 opt_time = end_optimizer - start_optimizer
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:111, in GlobalOptimizer.__call__(self, fgraph)
105 def __call__(self, fgraph):
106 """Optimize a `FunctionGraph`.
107
108 This is the same as ``self.optimize(fgraph)``.
109
110 """
--> 111 return self.optimize(fgraph)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:102, in GlobalOptimizer.optimize(self, fgraph, *args, **kwargs)
93 """
94
95 This is meant as a shortcut for the following::
(...)
99
100 """
101 self.add_requirements(fgraph)
--> 102 ret = self.apply(fgraph, *args, **kwargs)
103 return ret
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:279, in SeqOptimizer.apply(self, fgraph)
277 nb_nodes_before = len(fgraph.apply_nodes)
278 t0 = time.time()
--> 279 sub_prof = optimizer.apply(fgraph)
280 l.append(float(time.time() - t0))
281 sub_profs.append(sub_prof)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1971, in TopoOptimizer.apply(self, fgraph, start_from)
1969 continue
1970 current_node = node
-> 1971 nb += self.process_node(fgraph, node)
1972 loop_t = time.time() - t0
1973 finally:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1864, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1862 except Exception as e:
1863 if self.failure_callback is not None:
-> 1864 self.failure_callback(
1865 e, self, [(x, None) for x in node.outputs], lopt, node
1866 )
1867 return False
1868 else:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1767, in NavigatorOptimizer.warn_inplace(exc, nav, repl_pairs, local_opt, node)
1765 if isinstance(exc, InconsistencyError):
1766 return
-> 1767 return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1755, in NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
1751 pdb.post_mortem(sys.exc_info()[2])
1752 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1753 # We always crash on AssertionError because something may be
1754 # seriously wrong if such an exception is raised.
-> 1755 raise exc
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1861, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1859 lopt = lopt or self.local_opt
1860 try:
-> 1861 replacements = lopt.transform(fgraph, node)
1862 except Exception as e:
1863 if self.failure_callback is not None:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1066, in FromFunctionLocalOptimizer.transform(self, fgraph, node)
1061 if not (
1062 node.op in self._tracks or isinstance(node.op, self._tracked_types)
1063 ):
1064 return False
-> 1066 return self.fn(fgraph, node)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py:1203, in local_IncSubtensor_serialize(fgraph, node)
1201 for mi in movable_inputs:
1202 assert o_type.is_super(tip.type)
-> 1203 assert mi.owner.inputs[0].type.is_super(tip.type)
1204 tip = mi.owner.op(tip, *mi.owner.inputs[1:])
1205 # Copy over stacktrace from outputs of the original
1206 # "movable" operation to the new operation.
AssertionError:
Apply node that caused the error: for{cpu,scan_fn&scan_fn&scan_fn&scan_fn&scan_fn&scan_fn&scan_fn}(TensorConstant{175}, TensorConstant{[ 0 1 ..2 173 174]}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, eps_log___log012, InplaceDimShuffle{x}.0, Elemwise{Composite{(i0 - (i1 + (i2 * i3)))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{true_div,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{Switch(i0, (i1 / i2), i3)}}.0, Elemwise{Composite{(Switch(i0, ((i1 * i2) / i3), i4) + (i5 / i2) + (i6 / i2))}}.0, Elemwise{neg,no_inplace}.0, Join.0, Elemwise{sqr,no_inplace}.0, Elemwise{Sqr}[(0, 0)].0, Elemwise{mul,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, sigma_b_log___log134, Elemwise{sqr,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{true_div,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, sigma_a_log___log256, Elemwise{sqr,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0)
Toposort index: 95
Inputs types: [TensorType(int64, ()), TensorType(int32, (175,)), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,))]
HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
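Following the optimizer hint above, I suppose I could recompile with most rewrites disabled, which should either sidestep the failing local_IncSubtensor_serialize rewrite or at least give a back-trace showing where the node was created. A minimal sketch of what I have in mind (untested; theta, b and hvp_clone are from the full code below):

import aesara

# Per the hint in the error output: compile with most rewrites disabled.
# config.change_flags is the same context manager that appears in the
# traceback frames above.
with aesara.config.change_flags(optimizer="fast_compile"):
    hvp_fn = aesara.function([theta, b], [hvp_clone])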
Full code here:
import numpy as np
import pandas as pd
import pymc as pm
import aesara

data = pd.read_csv(pm.get_data('radon.csv'))
data['log_radon'] = data['log_radon'].astype(aesara.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')
n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a.
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx] * data.floor.values

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=data.log_radon)

init_point = m.initial_point()

import aesara.tensor as at

# Ravel all free values into a single flat vector
q = pm.blocking.DictToArrayBijection.map({v.name: init_point[v.name] for v in m.vars})

# Hessian-vector product built from the dense Hessian
b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    np.asarray([0]),
    np.cumsum([
        np.prod(v)
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)

vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))

hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = aesara.function([theta, b], [hvp_clone])
hvp_fn(q.data, q.data)
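One more thought, in case it helps with debugging: maybe the Hessian-vector product could be built without materialising the dense Hessian at all, via the usual gradient-of-(gradient · vector) trick. This is only a sketch, assuming m.logp() returns the joint log-density as a scalar and that vars are the same continuous inputs as above; I haven't checked whether it avoids the rewrite failure:

# Hypothetical alternative: HVP via double backprop instead of m.d2logp()
logp = m.logp()                    # scalar joint log-probability (assumption)
grads = at.grad(logp, vars)        # one gradient per free variable
flat_grad = at.concatenate([g.flatten() for g in grads])
# Differentiating (grad . b) w.r.t. the variables gives the HVP, one piece
# per variable; these would still need the same theta replacement as above.
hvp_pieces = at.grad(at.dot(flat_grad, b), vars)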
Thanks again for your help! And no problem if this version doesn’t work out, I’m happy to stick with the first one.