Computing the gradient and hvp of the log posterior

Hi all,

I’m doing some work where it would be very useful to access the log posterior’s gradient, and also the product of its Hessian with a vector. Specifically, I’m after two functions, derived somehow from a PyMC model `m`:

• `val_and_grad`, taking a vector of parameters `theta` and returning the value and gradient of the log posterior. The gradient is a vector of dimension `D`, where `D` is the number of parameters in the model.
• `hvp`, taking a vector of parameters `theta` and a second vector `b`, both of dimension `D`, and returning another vector of dimension `D`: the product `H(theta) b`, where `H` is the Hessian of the log posterior.
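To pin down the interface being asked for, here is a minimal NumPy sketch on a hypothetical toy target rather than a PyMC model: a Gaussian log density with mean `mu` and precision matrix `P`, both made up for illustration. Its gradient is `-P (theta - mu)` and its Hessian is the constant `-P`, so `hvp(theta, b)` reduces to `-P @ b`.

```python
import numpy as np

# Hypothetical toy target (not a PyMC model):
#   log p(theta) = -0.5 * (theta - mu)^T P (theta - mu) + const
D = 3
rng = np.random.default_rng(0)
mu = rng.normal(size=D)
L = rng.normal(size=(D, D))
P = L @ L.T + D * np.eye(D)  # symmetric positive definite precision matrix


def val_and_grad(theta):
    """Return (log density up to a constant, gradient) for a length-D vector."""
    r = theta - mu
    return -0.5 * r @ P @ r, -P @ r


def hvp(theta, b):
    """Return H(theta) @ b; for a Gaussian the Hessian is -P at every theta."""
    return -P @ b


theta = np.zeros(D)
b = np.ones(D)
value, grad = val_and_grad(theta)
```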

Is there a recommended way to do this? I’m able to do it with the JAX backend, but would like to do it with pure PyMC too. Thanks for your help!

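The standard way is to use `model.logp_dlogp_function`, which builds a `ValueGradFunction`; this is what the HMC samplers in PyMC use internally:

``````
init_point = m.initial_point()

# ValueGradFunction, as used internally by the HMC samplers
logp_dlogp_func = m.logp_dlogp_function()
logp_dlogp_func.set_extra_values(init_point)

# Flatten the point dict into a single vector (a RaveledVars with .data and .point_map_info)
q = pm.blocking.DictToArrayBijection.map({v.name: init_point[v.name] for v in m.vars})
logp_dlogp_func(q)
# same output as:
#   m.compile_logp()(init_point), m.compile_dlogp()(init_point)
``````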
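`logp_dlogp_func` takes one flat vector rather than a dict of named values. As a rough NumPy illustration of what that flattening amounts to (not PyMC’s actual implementation), the hypothetical helper `ravel_point` below concatenates every value, scalars included, and records `(name, shape, dtype)` for each block, the same kind of information `q.point_map_info` carries in the later snippets. The example point and its values are made up.

```python
import numpy as np


def ravel_point(point):
    """Hypothetical helper: flatten a dict of arrays into one 1-d vector.

    Returns the flat data plus (name, shape, dtype) for every entry, so the
    vector can later be cut back into the original named arrays.
    """
    info = [(name, np.shape(value), np.asarray(value).dtype) for name, value in point.items()]
    data = np.concatenate([np.ravel(value) for value in point.values()])
    return data, info


# Made-up point: two scalars and a length-4 vector
point = {"mu_a": np.array(0.0), "sigma_a_log__": np.array(1.6), "a": np.zeros(4)}
data, info = ravel_point(point)
```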

I’m not sure we do anything particularly smart for computing `hvp`, but multiplying `model.d2logp()` by the input should work:

``````
value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
m.compile_fn(m.d2logp() @ value_var)(init_point)
# same output as:
#   m.compile_d2logp()(init_point) @ q.data
``````
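Multiplying by `d2logp()` materializes the full `D` by `D` Hessian first, which costs `O(D^2)` memory. A standard identity avoids that: the Hessian-vector product is the gradient of the scalar directional derivative `grad log p(theta)^T b`, with `b` held fixed, so it can be obtained by differentiating twice in reverse mode at roughly the cost of a few gradient evaluations. Whether Aesara’s rewrites would produce such a graph here was not checked in this thread.

```latex
H(\theta)\,b = \bigl(\nabla_\theta^2 \log p(\theta)\bigr)\,b
             = \nabla_\theta\!\left[\nabla_\theta \log p(\theta)^\top b\right],
\qquad b \text{ held constant.}
```

For example, with $\log p(\theta) = -\tfrac12(\theta-\mu)^\top P(\theta-\mu) + c$ and symmetric $P$, the bracket is $-(\theta-\mu)^\top P b$, whose gradient in $\theta$ is $-Pb$, matching $H = -P$.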

Hi Junpeng,

Thank you very much for replying so quickly and with helpful code!

The first part (value and grad) seems to be working well. I do get a warning:

``````
"/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:941: FutureWarning: Model.vars has been deprecated. Use Model.value_vars instead."
``````

Should I replace `Model.vars` with `Model.value_vars`?

For the hvp, `m.compile_d2logp()(init_point) @ q.data` works, but unfortunately the (presumably more efficient) code you sent throws an error:

``````
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:141, in _as_tensor_Sequence(x, name, ndim, dtype, **kwargs)
140 try:
--> 141     x = type(x)(extract_constants(i) for i in x)
142 except TypeError:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:141, in <genexpr>(.0)
140 try:
--> 141     x = type(x)(extract_constants(i) for i in x)
142 except TypeError:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:136, in _as_tensor_Sequence.<locals>.extract_constants(i)
135     else:
--> 136         raise TypeError
137 else:

TypeError:

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Input In [5], in <cell line: 2>()
1 value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
----> 2 m.compile_fn(m.d2logp() @ value_var)(init_point)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/var.py:646, in _tensor_py_operators.__dot__(left, right)
645 def __dot__(left, right):
--> 646     return at.math.dense_dot(left, right)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/math.py:2075, in dense_dot(a, b)
2038 def dense_dot(a, b):
2039     """
2040     Computes the dot product of two variables.
2041
(...)
2073
2074     """
-> 2075     a, b = as_tensor_variable(a), as_tensor_variable(b)
2077     if not isinstance(a.type, DenseTensorType) or not isinstance(
2078         b.type, DenseTensorType
2079     ):
2080         raise TypeError("The dense dot product is only supported for dense types")

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/__init__.py:42, in as_tensor_variable(x, name, ndim, **kwargs)
10 def as_tensor_variable(
11     x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
12 ) -> "TensorVariable":
13     """Convert `x` into an equivalent `TensorVariable`.
14
15     This function can be used to turn ndarrays, numbers, `ScalarType` instances,
(...)
40
41     """
---> 42     return _as_tensor_variable(x, name, ndim, **kwargs)

File ~/miniconda3/envs/pymc_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886     raise TypeError(f'{funcname} requires at least '
887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:155, in _as_tensor_Sequence(x, name, ndim, dtype, **kwargs)
150         return MakeVector(dtype)(*x)
152     # In this case, we have at least one non-`Constant` term, so we
153     # couldn't get an underlying non-symbolic sequence of objects and we to
154     # symbolically join terms.
--> 155     return stack(x)
157 return constant(x, name=name, ndim=ndim, dtype=dtype)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:2780, in stack(*tensors, **kwargs)
2778     dtype = aes.upcast(*[i.dtype for i in tensors])
2779     return MakeVector(dtype)(*tensors)
-> 2780 return join(axis, *[shape_padaxis(t, axis) for t in tensors])

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:2631, in join(axis, *tensors_list)
2629     return tensors_list[0]
2630 else:
-> 2631     return join_(axis, *tensors_list)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py:296, in Op.__call__(self, *inputs, **kwargs)
254 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
255
256 This method is just a wrapper around :meth:`Op.make_node`.
(...)
293
294 """
295 return_list = kwargs.pop("return_list", False)
--> 296 node = self.make_node(*inputs, **kwargs)
298 if config.compute_test_value != "off":
299     compute_test_value(node)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py:2418, in Join.make_node(self, axis, *tensors)
2415         bcastable = [False] * len(tensors[0].type.broadcastable)
2417 if not builtins.all(x.ndim == len(bcastable) for x in tensors):
-> 2418     raise TypeError(
2419         "Only tensors with the same number of dimensions can be joined"
2420     )
2422 inputs = [as_tensor_variable(axis)] + list(tensors)
2424 if inputs[0].type.dtype not in int_dtypes:

TypeError: Only tensors with the same number of dimensions can be joined
``````

My version of PyMC may not be the very latest. Could that be the problem, or is something else going on?

Here’s the full code:

``````
import numpy as np
import pandas as pd
import pymc as pm
import aesara

county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')

n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    # Data likelihood

init_point = m.initial_point()

value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
m.compile_fn(m.d2logp() @ value_var)(init_point)
``````

Thanks again

Yes!

You are right, it doesn’t work because I was testing with a model that contains only scalars. Try this:

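``````
import aesara.tensor as at

hessian = m.d2logp()
# The continuous value variables the Hessian graph depends on
vars = pm.aesaraf.cont_inputs(hessian)
# Flatten each one to 1-d before joining, so scalars and vectors can be concatenated
value_var = at.concatenate([at.flatten(v) for v in vars], axis=0)
# value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
hvp = m.compile_fn(hessian @ value_var)
hvp(init_point)
``````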
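The fix works because of shapes rather than anything PyMC-specific: a Python list mixing scalars and vectors cannot be stacked into one tensor, which is what the `Only tensors with the same number of dimensions can be joined` error was about. NumPy behaves analogously (its error type and message differ from Aesara’s), so the idea can be checked without a model; the values below are made up.

```python
import numpy as np

# A scalar parameter next to a vector parameter, like mu_a and a in the model
values = [np.array(0.5), np.array([1.0, 2.0, 3.0])]

# Stacking needs identical shapes, so mixing a 0-d and a 1-d array fails
try:
    np.stack(values)
    stacked_ok = True
except ValueError:
    stacked_ok = False

# Flattening every value to 1-d first makes them joinable along axis 0
flat = np.concatenate([np.ravel(v) for v in values])
```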

BTW, you can also check out `pymc/aesaraf.py` on the main branch of the pymc-devs/pymc GitHub repository for more on how gradients are handled in PyMC.


Thanks a lot Junpeng! This runs now. I’m just stuck on one point: the hvp should be a function of two variables:

``````
hvp(x, y)
``````

so that the Hessian is computed at `x` and then multiplied by the vector `y`. The snippet you sent computes the hvp as a function of only one variable; I think it effectively does

``````
hvp(init_point, y)
``````

How would I feed a different `x` to the hvp?

Thanks again for your help on this

Oh right, I missed that part. In that case, it’s probably easier to compile an Aesara function directly:

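``````
import aesara
import aesara.tensor as at

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

hvp_fn = aesara.function(vars + [b], [hvp])
# Assumes init_point lists its values in the same order as `vars`
hvp_fn(*init_point.values(), q.data)
``````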
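Why `theta` needs to be an argument of its own: unless the log density is quadratic, the Hessian changes with `theta`, so `H(theta) b` evaluated at `init_point` says nothing about other points. A small NumPy sketch with a hypothetical non-Gaussian target, `log p(theta) = -0.5 theta^T A theta - sum(log cosh(theta))`, whose Hessian is `-A - diag(1 - tanh(theta)^2)`:

```python
import numpy as np

# Hypothetical non-Gaussian toy target:
#   log p(theta) = -0.5 * theta^T A theta - sum(log(cosh(theta)))
A = np.array([[2.0, 0.5], [0.5, 1.0]])


def grad_logp(theta):
    """Gradient of the toy log density."""
    return -A @ theta - np.tanh(theta)


def hvp(theta, b):
    """H(theta) @ b with H = -A - diag(1 - tanh(theta)**2); theta and b are both length 2."""
    return -A @ b - (1.0 - np.tanh(theta) ** 2) * b


b = np.array([1.0, -1.0])
hvp_origin = hvp(np.zeros(2), b)
hvp_elsewhere = hvp(np.array([2.0, -3.0]), b)  # same b, different theta, different result
```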

You could even clone the subgraph so that you work with the flat vectors `theta` and `b` directly:

``````
b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    np.asarray([0]),
    np.cumsum([
        np.prod(v)
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = aesara.function([theta, b], [hvp_clone])
hvp_fn(q.data, q.data)
``````
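The `split_point` bookkeeping can be checked on its own with plain NumPy. The sketch below uses a hypothetical `point_map_info` in the same `(name, shape, dtype)` form the snippet unpacks from `q.point_map_info`: cumulative block sizes give each block’s offset in the flat vector, and each slice is reshaped back to its original shape, which is what the `at.reshape(theta[...], v)` lines do symbolically before `clone_replace`.

```python
import numpy as np

# Hypothetical (name, shape, dtype) triples, in the form unpacked from q.point_map_info
point_map_info = [("mu_a", (), np.float64), ("a", (3,), np.float64), ("eps_log__", (), np.float64)]

# Number of scalar entries in each block (np.prod(()) == 1 for scalars)
sizes = [int(np.prod(shape)) for _, shape, _ in point_map_info]
# Start/end offsets of each block inside the flat vector
split_point = np.concatenate([[0], np.cumsum(sizes)]).astype(int)


def unravel(theta):
    """Cut a flat vector into blocks and reshape each to its original shape."""
    return {
        name: theta[split_point[i]:split_point[i + 1]].reshape(shape)
        for i, (name, shape, _) in enumerate(point_map_info)
    }


parts = unravel(np.arange(5.0))
```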

In addition to what @junpenglao said, I suggest you always use `pymc.aesaraf.compile_pymc` instead of calling `aesara.function` directly, as it automatically applies PyMC-specific rewrites (e.g. replacing logp assertions with -inf switches).


Junpeng, the first snippet seems to run fine for me, thanks! The second one looks cool, but unfortunately I get a long error message:

``````
ERROR (aesara.graph.opt): Optimization failure due to: local_IncSubtensor_serialize
ERROR (aesara.graph.opt): TRACEBACK:
ERROR (aesara.graph.opt): Traceback (most recent call last):
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py", line 1861, in process_node
replacements = lopt.transform(fgraph, node)
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py", line 1066, in transform
return self.fn(fgraph, node)
File "/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py", line 1203, in local_IncSubtensor_serialize
assert mi.owner.inputs[0].type.is_super(tip.type)
AssertionError

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1534, in Scan.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
1531 # Analyse the compile inner function to determine which inputs and
1532 # outputs are on the gpu and speed up some checks during the execution
1533 outs_is_tensor = [
-> 1534     isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
1535 ]
1537 try:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1466, in Scan.fn(self)
1464     profile = self.profile
-> 1466 self._fn = pfunc(
1467     wrapped_inputs,
1468     wrapped_outputs,
1469     mode=self.mode_instance,
1470     accept_inplace=False,
1471     profile=profile,
1472     on_unused_input="ignore",
1473     fgraph=self.fgraph,
1474 )
1476 return self._fn

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363     params,
364     outputs,
(...)
371     fgraph=fgraph,
372 )
--> 374 return orig_function(
375     inputs,
376     cloned_outputs,
377     mode,
378     accept_inplace=accept_inplace,
379     name=name,
380     profile=profile,
381     on_unused_input=on_unused_input,
382     output_keys=output_keys,
383     fgraph=fgraph,
384 )

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1751, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1750 Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751 m = Maker(
1752     inputs,
1753     outputs,
1754     mode,
1755     accept_inplace=accept_inplace,
1756     profile=profile,
1757     on_unused_input=on_unused_input,
1758     output_keys=output_keys,
1759     name=name,
1760     fgraph=fgraph,
1761 )
1762 with config.change_flags(compute_test_value="off"):

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1521, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1520 if not no_fgraph_prep:
-> 1521     self.prepare_fgraph(
1523     )
1525 assert len(fgraph.outputs) == len(outputs + found_updates)

1407 with config.change_flags(
1408     compute_test_value=config.compute_test_value_opt,
1409     traceback__limit=config.traceback__compile_limit,
1410 ):
-> 1411     optimizer_profile = optimizer(fgraph)
1413     end_optimizer = time.time()

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:111, in GlobalOptimizer.__call__(self, fgraph)
106 """Optimize a `FunctionGraph`.
107
108 This is the same as ``self.optimize(fgraph)``.
109
110 """
--> 111 return self.optimize(fgraph)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:102, in GlobalOptimizer.optimize(self, fgraph, *args, **kwargs)
--> 102 ret = self.apply(fgraph, *args, **kwargs)
103 return ret

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:279, in SeqOptimizer.apply(self, fgraph)
278 t0 = time.time()
--> 279 sub_prof = optimizer.apply(fgraph)
280 l.append(float(time.time() - t0))

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1971, in TopoOptimizer.apply(self, fgraph, start_from)
1970     current_node = node
-> 1971     nb += self.process_node(fgraph, node)
1972 loop_t = time.time() - t0

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1864, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1863 if self.failure_callback is not None:
-> 1864     self.failure_callback(
1865         e, self, [(x, None) for x in node.outputs], lopt, node
1866     )
1867     return False

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1767, in NavigatorOptimizer.warn_inplace(exc, nav, repl_pairs, local_opt, node)
1766     return
-> 1767 return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1755, in NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
1752 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1753     # We always crash on AssertionError because something may be
1754     # seriously wrong if such an exception is raised.
-> 1755     raise exc

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1861, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1860 try:
-> 1861     replacements = lopt.transform(fgraph, node)
1862 except Exception as e:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1066, in FromFunctionLocalOptimizer.transform(self, fgraph, node)
1064         return False
-> 1066 return self.fn(fgraph, node)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py:1203, in local_IncSubtensor_serialize(fgraph, node)
1202 assert o_type.is_super(tip.type)
-> 1203 assert mi.owner.inputs[0].type.is_super(tip.type)
1204 tip = mi.owner.op(tip, *mi.owner.inputs[1:])

AssertionError:

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Input In [5], in <cell line: 20>()
17     vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
18 hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))
---> 20 hvp_fn = aesara.function([theta, b], [hvp_clone])
21 hvp_fn(q.data, q.data)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/__init__.py:317, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
311     fn = orig_function(
312         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
313     )
314 else:
315     # note: pfunc will also call orig_function -- orig_function is
316     #      a choke point that all compilation must pass through
--> 317     fn = pfunc(
318         params=inputs,
319         outputs=outputs,
320         mode=mode,
322         givens=givens,
324         accept_inplace=accept_inplace,
325         name=name,
326         rebuild_strict=rebuild_strict,
327         allow_input_downcast=allow_input_downcast,
328         on_unused_input=on_unused_input,
329         profile=profile,
330         output_keys=output_keys,
331     )
332 return fn

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
360     profile = ProfileStats(message=profile)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363     params,
364     outputs,
(...)
371     fgraph=fgraph,
372 )
--> 374 return orig_function(
375     inputs,
376     cloned_outputs,
377     mode,
378     accept_inplace=accept_inplace,
379     name=name,
380     profile=profile,
381     on_unused_input=on_unused_input,
382     output_keys=output_keys,
383     fgraph=fgraph,
384 )

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1763, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1751     m = Maker(
1752         inputs,
1753         outputs,
(...)
1760         fgraph=fgraph,
1761     )
1762     with config.change_flags(compute_test_value="off"):
-> 1763         fn = m.create(defaults)
1764 finally:
1765     t2 = time.time()

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1656, in FunctionMaker.create(self, input_storage, trustme, storage_map)
1655 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1656     _fn, _i, _o = self.linker.make_thunk(
1657         input_storage=input_storage_lists, storage_map=storage_map
1658     )

247 def make_thunk(
248     self,
249     input_storage: Optional["InputStorageType"] = None,
(...)
252     **kwargs,
253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254     return self.make_all(
255         input_storage=input_storage,
256         output_storage=output_storage,
257         storage_map=storage_map,
258     )[:3]

1252             thunks[-1].lazy = False
1253     except Exception:
-> 1254         raise_with_op(fgraph, node)
1256 t1 = time.time()
1258 if self.profile:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529     warnings.warn(
530         f"{exc_type} error does not allow us to add an extra error message"
531     )
532     # Some exception need extra parameter in inputs. So forget the
533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

1240 thunk_start = time.time()
1241 # no-recycling is done at each VM.__call__ So there is
1242 # no need to cause duplicate c code by passing
1243 # no_recycling here.
1244 thunks.append(
-> 1245     node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
1246 )
1247 linker_make_thunk_time[node] = time.time() - thunk_start
1248 if not hasattr(thunks[-1], "lazy"):
1249     # We don't want all ops maker to think about lazy Ops.
1250     # So if they didn't specify that its lazy or not, it isn't.
1251     # If this member isn't present, it will crash later.

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1534, in Scan.make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
1529 node_output_storage = [storage_map[r] for r in node.outputs]
1531 # Analyse the compile inner function to determine which inputs and
1532 # outputs are on the gpu and speed up some checks during the execution
1533 outs_is_tensor = [
-> 1534     isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
1535 ]
1537 try:
1538     if impl == "py":

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/scan/op.py:1466, in Scan.fn(self)
1463 elif self.profile:
1464     profile = self.profile
-> 1466 self._fn = pfunc(
1467     wrapped_inputs,
1468     wrapped_outputs,
1469     mode=self.mode_instance,
1470     accept_inplace=False,
1471     profile=profile,
1472     on_unused_input="ignore",
1473     fgraph=self.fgraph,
1474 )
1476 return self._fn

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/pfunc.py:374, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
360     profile = ProfileStats(message=profile)
362 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
363     params,
364     outputs,
(...)
371     fgraph=fgraph,
372 )
--> 374 return orig_function(
375     inputs,
376     cloned_outputs,
377     mode,
378     accept_inplace=accept_inplace,
379     name=name,
380     profile=profile,
381     on_unused_input=on_unused_input,
382     output_keys=output_keys,
383     fgraph=fgraph,
384 )

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1751, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1749 try:
1750     Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751     m = Maker(
1752         inputs,
1753         outputs,
1754         mode,
1755         accept_inplace=accept_inplace,
1756         profile=profile,
1757         on_unused_input=on_unused_input,
1758         output_keys=output_keys,
1759         name=name,
1760         fgraph=fgraph,
1761     )
1762     with config.change_flags(compute_test_value="off"):
1763         fn = m.create(defaults)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:1521, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
1520 if not no_fgraph_prep:
-> 1521     self.prepare_fgraph(
1523     )
1525 assert len(fgraph.outputs) == len(outputs + found_updates)
1527 # The 'no_borrow' outputs are the ones for which that we can't
1528 # return the internal storage pointer.

1405 opt_time = None
1407 with config.change_flags(
1408     compute_test_value=config.compute_test_value_opt,
1409     traceback__limit=config.traceback__compile_limit,
1410 ):
-> 1411     optimizer_profile = optimizer(fgraph)
1413     end_optimizer = time.time()
1414     opt_time = end_optimizer - start_optimizer

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:111, in GlobalOptimizer.__call__(self, fgraph)
105 def __call__(self, fgraph):
106     """Optimize a `FunctionGraph`.
107
108     This is the same as ``self.optimize(fgraph)``.
109
110     """
--> 111     return self.optimize(fgraph)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:102, in GlobalOptimizer.optimize(self, fgraph, *args, **kwargs)
93 """
94
95 This is meant as a shortcut for the following::
(...)
99
100 """
--> 102 ret = self.apply(fgraph, *args, **kwargs)
103 return ret

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:279, in SeqOptimizer.apply(self, fgraph)
277 nb_nodes_before = len(fgraph.apply_nodes)
278 t0 = time.time()
--> 279 sub_prof = optimizer.apply(fgraph)
280 l.append(float(time.time() - t0))
281 sub_profs.append(sub_prof)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1971, in TopoOptimizer.apply(self, fgraph, start_from)
1969             continue
1970         current_node = node
-> 1971         nb += self.process_node(fgraph, node)
1972     loop_t = time.time() - t0
1973 finally:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1864, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1862 except Exception as e:
1863     if self.failure_callback is not None:
-> 1864         self.failure_callback(
1865             e, self, [(x, None) for x in node.outputs], lopt, node
1866         )
1867         return False
1868     else:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1767, in NavigatorOptimizer.warn_inplace(exc, nav, repl_pairs, local_opt, node)
1765 if isinstance(exc, InconsistencyError):
1766     return
-> 1767 return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1755, in NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)
1751     pdb.post_mortem(sys.exc_info()[2])
1752 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
1753     # We always crash on AssertionError because something may be
1754     # seriously wrong if such an exception is raised.
-> 1755     raise exc

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1861, in NavigatorOptimizer.process_node(self, fgraph, node, lopt)
1859 lopt = lopt or self.local_opt
1860 try:
-> 1861     replacements = lopt.transform(fgraph, node)
1862 except Exception as e:
1863     if self.failure_callback is not None:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/opt.py:1066, in FromFunctionLocalOptimizer.transform(self, fgraph, node)
1061     if not (
1062         node.op in self._tracks or isinstance(node.op, self._tracked_types)
1063     ):
1064         return False
-> 1066 return self.fn(fgraph, node)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/subtensor_opt.py:1203, in local_IncSubtensor_serialize(fgraph, node)
1201 for mi in movable_inputs:
1202     assert o_type.is_super(tip.type)
-> 1203     assert mi.owner.inputs[0].type.is_super(tip.type)
1204     tip = mi.owner.op(tip, *mi.owner.inputs[1:])
1205     # Copy over stacktrace from outputs of the original
1206     # "movable" operation to the new operation.

AssertionError:
Apply node that caused the error: for{cpu,scan_fn&scan_fn&scan_fn&scan_fn&scan_fn&scan_fn&scan_fn}(TensorConstant{175}, TensorConstant{[  0   1  ..2 173 174]}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, TensorConstant{175}, eps_log___log012, InplaceDimShuffle{x}.0, Elemwise{Composite{(i0 - (i1 + (i2 * i3)))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{true_div,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{Switch(i0, (i1 / i2), i3)}}.0, Elemwise{Composite{(Switch(i0, ((i1 * i2) / i3), i4) + (i5 / i2) + (i6 / i2))}}.0, Elemwise{neg,no_inplace}.0, Join.0, Elemwise{sqr,no_inplace}.0, Elemwise{Sqr}[(0, 0)].0, Elemwise{mul,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, sigma_b_log___log134, Elemwise{sqr,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{true_div,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, sigma_a_log___log256, Elemwise{sqr,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0)
Toposort index: 95
Inputs types: [TensorType(int64, ()), TensorType(int32, (175,)), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,))]

HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
``````

Full code here:

``````
import numpy as np
import pandas as pd
import pymc as pm
import aesara

county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')

n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    # Data likelihood

init_point = m.initial_point()

import aesara.tensor as at

q = pm.blocking.DictToArrayBijection.map({v.name: init_point[v.name] for v in m.vars})

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    np.asarray([0]),
    np.cumsum([
        np.prod(v)
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = aesara.function([theta, b], [hvp_clone])
hvp_fn(q.data, q.data)
``````

Thanks again for your help! And no problem if this version doesn’t work out; I’m happy to use the first one.

Seems like a bug? The same approach works with `m.logp()` and `m.dlogp()`:

``````
theta = at.vector(name='theta')
# theta.tag.test_value = q.data
logp = m.logp()
vars = pm.aesaraf.cont_inputs(logp)

split_point = np.concatenate([
    np.asarray([0]),
    np.cumsum([
        np.prod(v)
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
logp_clone = aesara.clone_replace(logp, dict(zip(vars, vars_replace)))

logp_fn = pm.aesaraf.compile_pymc([theta], [logp_clone])
logp_fn(q.data)
``````

Maybe @ricardoV94 could take a look?

I tried a few different ways, such as cloning the logp and then taking the gradient of the gradient, but I keep getting the same error, so I can’t find an easy fix for this.
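Since `logp` and `dlogp` compile fine with the flattened `theta`, one possible workaround (a numerical approximation, not a fix for the underlying error) is to approximate the HVP from the compiled gradient alone with a central difference along `b`: `H(theta) b ≈ (grad(theta + eps b) - grad(theta - eps b)) / (2 eps)`. The sketch assumes `grad_fn` is any callable mapping a flat `theta` to the flat gradient, for instance one compiled from the cloned `dlogp` graph; the quadratic target in the check is made up, and for it the central difference is exact up to round-off.

```python
import numpy as np


def make_fd_hvp(grad_fn, eps=1e-5):
    """Approximate H(theta) @ b by a central difference of the gradient along b.

    Numerical fallback, not an exact Hessian-vector product: the error is
    O(eps**2) (scaled by third derivatives) plus floating-point round-off.
    """
    def hvp(theta, b):
        theta = np.asarray(theta, dtype=float)
        b = np.asarray(b, dtype=float)
        return (grad_fn(theta + eps * b) - grad_fn(theta - eps * b)) / (2.0 * eps)

    return hvp


# Made-up quadratic check: log p = -0.5 theta^T P theta, so the exact answer is -P @ b
P = np.array([[3.0, 1.0], [1.0, 2.0]])
hvp = make_fd_hvp(lambda theta: -P @ theta)
approx = hvp(np.array([0.3, -0.7]), np.array([1.0, 2.0]))
```

If `b` has a very large or very small norm, rescaling it (or `eps`) before differencing keeps the step size sensible.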

Hi Junpeng, thanks for checking. No problem – I think I should be able to make the code you wrote earlier work for my purposes. So I’m marking the thread as solved, but if this is an interesting bug to explore, of course I’d be interested in any updates down the road.