I am having trouble getting any model that uses an Aesara `scan` to sample. Below is a minimal example that reproduces the problem, followed by the error I am seeing. Does anyone have advice on what I am doing wrong?
from aesara import scan
import pymc as pm
# Minimal reproduction: a PyMC model whose likelihood depends on the output of
# an Aesara `scan`. Calling `pm.sample()` on it fails while PyMC checks whether
# the model log-probability can be differentiated (see the traceback below).
with pm.Model() as model:
    # Scalar prior that seeds the recurrence.
    μ = pm.Normal("μ", 0, 2.5)

    # With no `sequences` or `non_sequences`, the step function receives only
    # the previous value of the recurrent output initialised via
    # `outputs_info`, so this identity step just carries μ forward for
    # `n_steps` iterations. `scan` returns (outputs, updates); the updates are
    # not needed here.
    μ_, _ = scan(
        lambda x: x,
        outputs_info=μ,
        n_steps=2,
    )

    # The likelihood uses the value produced by the final scan step.
    pm.Normal("obs", μ_[-1], 1, observed=[0])

with model:
    # NOTE(review): the traceback ends in Aesara's `Scan.L_op`, where
    # `minimum(grad_steps, self.truncate_gradient)` is handed a Python bool
    # (presumably `truncate_gradient` itself), even though this code never sets
    # `truncate_gradient` (scan's default is -1). That suggests a problem in
    # the installed Aesara/PyMC versions rather than in this model — try
    # upgrading both before changing the model.
    trace = pm.sample()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [3], in <cell line: 1>()
1 with model:
----> 2 trace = pm.sample()
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/pymc/sampling.py:472, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
469 auto_nuts_init = False
471 initial_points = None
--> 472 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
474 if isinstance(step, list):
475 step = CompoundStep(step)
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/pymc/sampling.py:206, in assign_step_methods(model, step, methods, step_kwargs)
204 if has_gradient:
205 try:
--> 206 tg.grad(model_logpt, var)
207 except (NotImplementedError, tg.NullTypeGradError):
208 has_gradient = False
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:630, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
627 if hasattr(g.type, "dtype"):
628 assert g.type.dtype in aesara.tensor.type.float_dtypes
--> 630 rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
632 for i in range(len(rval)):
633 if isinstance(rval[i].type, NullType):
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1434, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1431 # end if cache miss
1432 return grad_dict[var]
-> 1434 rval = [access_grad_cache(elem) for elem in wrt]
1436 return rval
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1434, in <listcomp>(.0)
1431 # end if cache miss
1432 return grad_dict[var]
-> 1434 rval = [access_grad_cache(elem) for elem in wrt]
1436 return rval
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
1384 for node in node_to_idx:
1385 for idx in node_to_idx[node]:
-> 1387 term = access_term_cache(node)[idx]
1389 if not isinstance(term, Variable):
1390 raise TypeError(
1391 f"{node.op}.grad returned {type(term)}, expected"
1392 " Variable instance."
1393 )
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in _populate_grad_dict.<locals>.access_term_cache(node)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in <listcomp>(.0)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
1384 for node in node_to_idx:
1385 for idx in node_to_idx[node]:
-> 1387 term = access_term_cache(node)[idx]
1389 if not isinstance(term, Variable):
1390 raise TypeError(
1391 f"{node.op}.grad returned {type(term)}, expected"
1392 " Variable instance."
1393 )
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in _populate_grad_dict.<locals>.access_term_cache(node)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in <listcomp>(.0)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
1384 for node in node_to_idx:
1385 for idx in node_to_idx[node]:
-> 1387 term = access_term_cache(node)[idx]
1389 if not isinstance(term, Variable):
1390 raise TypeError(
1391 f"{node.op}.grad returned {type(term)}, expected"
1392 " Variable instance."
1393 )
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in _populate_grad_dict.<locals>.access_term_cache(node)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in <listcomp>(.0)
1055 if node not in term_dict:
1057 inputs = node.inputs
-> 1059 output_grads = [access_grad_cache(var) for var in node.outputs]
1061 # list of bools indicating if each output is connected to the cost
1062 outputs_connected = [
1063 not isinstance(g.type, DisconnectedType) for g in output_grads
1064 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
1384 for node in node_to_idx:
1385 for idx in node_to_idx[node]:
-> 1387 term = access_term_cache(node)[idx]
1389 if not isinstance(term, Variable):
1390 raise TypeError(
1391 f"{node.op}.grad returned {type(term)}, expected"
1392 " Variable instance."
1393 )
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1214, in _populate_grad_dict.<locals>.access_term_cache(node)
1206 if o_shape != g_shape:
1207 raise ValueError(
1208 "Got a gradient of shape "
1209 + str(o_shape)
1210 + " on an output of shape "
1211 + str(g_shape)
1212 )
-> 1214 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1216 if input_grads is None:
1217 raise TypeError(
1218 f"{node.op}.grad returned NoneType, expected iterable."
1219 )
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/scan/op.py:2304, in Scan.L_op(self, inputs, outs, dC_douts)
2301 # Restrict the number of grad steps according to
2302 # self.truncate_gradient
2303 if self.truncate_gradient != -1:
-> 2304 grad_steps = minimum(grad_steps, self.truncate_gradient)
2306 self_inputs = self.inputs
2307 self_outputs = self.outputs
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/graph/op.py:294, in Op.__call__(self, *inputs, **kwargs)
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...)
291
292 """
293 return_list = kwargs.pop("return_list", False)
--> 294 node = self.make_node(*inputs, **kwargs)
296 if config.compute_test_value != "off":
297 compute_test_value(node)
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/elemwise.py:462, in Elemwise.make_node(self, *inputs)
456 def make_node(self, *inputs):
457 """
458 If the inputs have different number of dimensions, their shape
459 is left-completed to the greatest number of dimensions with 1s
460 using DimShuffle.
461 """
--> 462 inputs = [as_tensor_variable(i) for i in inputs]
463 out_dtypes, out_broadcastables, inputs = self.get_output_info(
464 DimShuffle, *inputs
465 )
466 outputs = [
467 TensorType(dtype=dtype, shape=broadcastable)()
468 for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
469 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/elemwise.py:462, in <listcomp>(.0)
456 def make_node(self, *inputs):
457 """
458 If the inputs have different number of dimensions, their shape
459 is left-completed to the greatest number of dimensions with 1s
460 using DimShuffle.
461 """
--> 462 inputs = [as_tensor_variable(i) for i in inputs]
463 out_dtypes, out_broadcastables, inputs = self.get_output_info(
464 DimShuffle, *inputs
465 )
466 outputs = [
467 TensorType(dtype=dtype, shape=broadcastable)()
468 for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
469 ]
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/__init__.py:42, in as_tensor_variable(x, name, ndim, **kwargs)
10 def as_tensor_variable(
11 x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
12 ) -> "TensorVariable":
13 """Convert `x` into an equivalent `TensorVariable`.
14
15 This function can be used to turn ndarrays, numbers, `Scalar` instances,
(...)
40
41 """
---> 42 return _as_tensor_variable(x, name, ndim, **kwargs)
File ~/miniconda3/envs/notebooks/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/basic.py:176, in _as_tensor_bool(x, name, ndim, **kwargs)
174 @_as_tensor_variable.register(bool)
175 def _as_tensor_bool(x, name, ndim, **kwargs):
--> 176 raise TypeError(
177 "Cannot cast True or False as a tensor variable. Please use "
178 "np.array(True) or np.array(False) if you need these constants. "
179 "This error might be caused by using the == operator on "
180 "Variables. v == w does not do what you think it does, "
181 "use aesara.tensor.eq(v, w) instead."
182 )
TypeError: Cannot cast True or False as a tensor variable. Please use np.array(True) or np.array(False) if you need these constants. This error might be caused by using the == operator on Variables. v == w does not do what you think it does, use aesara.tensor.eq(v, w) instead.