Problem using simple aesara scan in PyMC model

I am having trouble getting any model to sample when it uses an aesara scan. Here is a minimal example that reproduces the problem, along with the error I am seeing. Does anyone have advice on what I am doing wrong?

from aesara import scan
import pymc as pm

with pm.Model() as model:
    μ = pm.Normal("μ", 0, 2.5)

    # identity step: carry μ forward unchanged for two steps
    μ_, _ = scan(
        lambda x: x,
        outputs_info=μ,
        n_steps=2,
    )

    # use the final value of the scan as the mean of the likelihood
    pm.Normal("obs", μ_[-1], 1, observed=[0])

with model:
    trace = pm.sample()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [3], in <cell line: 1>()
      1 with model:
----> 2     trace = pm.sample()

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/pymc/sampling.py:472, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    469         auto_nuts_init = False
    471 initial_points = None
--> 472 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    474 if isinstance(step, list):
    475     step = CompoundStep(step)

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/pymc/sampling.py:206, in assign_step_methods(model, step, methods, step_kwargs)
    204 if has_gradient:
    205     try:
--> 206         tg.grad(model_logpt, var)
    207     except (NotImplementedError, tg.NullTypeGradError):
    208         has_gradient = False

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:630, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    627     if hasattr(g.type, "dtype"):
    628         assert g.type.dtype in aesara.tensor.type.float_dtypes
--> 630 rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
    632 for i in range(len(rval)):
    633     if isinstance(rval[i].type, NullType):

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1434, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1431     # end if cache miss
   1432     return grad_dict[var]
-> 1434 rval = [access_grad_cache(elem) for elem in wrt]
   1436 return rval

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1434, in <listcomp>(.0)
   1431     # end if cache miss
   1432     return grad_dict[var]
-> 1434 rval = [access_grad_cache(elem) for elem in wrt]
   1436 return rval

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1384 for node in node_to_idx:
   1385     for idx in node_to_idx[node]:
-> 1387         term = access_term_cache(node)[idx]
   1389         if not isinstance(term, Variable):
   1390             raise TypeError(
   1391                 f"{node.op}.grad returned {type(term)}, expected"
   1392                 " Variable instance."
   1393             )

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in _populate_grad_dict.<locals>.access_term_cache(node)
   1055 if node not in term_dict:
   1057     inputs = node.inputs
-> 1059     output_grads = [access_grad_cache(var) for var in node.outputs]
   1061     # list of bools indicating if each output is connected to the cost
   1062     outputs_connected = [
   1063         not isinstance(g.type, DisconnectedType) for g in output_grads
   1064     ]

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1059, in <listcomp>(.0)
   1055 if node not in term_dict:
   1057     inputs = node.inputs
-> 1059     output_grads = [access_grad_cache(var) for var in node.outputs]
   1061     # list of bools indicating if each output is connected to the cost
   1062     outputs_connected = [
   1063         not isinstance(g.type, DisconnectedType) for g in output_grads
   1064     ]

[... the three frames above (access_grad_cache at gradient.py:1387, access_term_cache at gradient.py:1059, and its list comprehension) repeat twice more as the gradient recursion descends through the graph ...]

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1387, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1384 for node in node_to_idx:
   1385     for idx in node_to_idx[node]:
-> 1387         term = access_term_cache(node)[idx]
   1389         if not isinstance(term, Variable):
   1390             raise TypeError(
   1391                 f"{node.op}.grad returned {type(term)}, expected"
   1392                 " Variable instance."
   1393             )

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/gradient.py:1214, in _populate_grad_dict.<locals>.access_term_cache(node)
   1206         if o_shape != g_shape:
   1207             raise ValueError(
   1208                 "Got a gradient of shape "
   1209                 + str(o_shape)
   1210                 + " on an output of shape "
   1211                 + str(g_shape)
   1212             )
-> 1214 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1216 if input_grads is None:
   1217     raise TypeError(
   1218         f"{node.op}.grad returned NoneType, expected iterable."
   1219     )

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/scan/op.py:2304, in Scan.L_op(self, inputs, outs, dC_douts)
   2301 # Restrict the number of grad steps according to
   2302 # self.truncate_gradient
   2303 if self.truncate_gradient != -1:
-> 2304     grad_steps = minimum(grad_steps, self.truncate_gradient)
   2306 self_inputs = self.inputs
   2307 self_outputs = self.outputs

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/graph/op.py:294, in Op.__call__(self, *inputs, **kwargs)
    252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    291 
    292 """
    293 return_list = kwargs.pop("return_list", False)
--> 294 node = self.make_node(*inputs, **kwargs)
    296 if config.compute_test_value != "off":
    297     compute_test_value(node)

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/elemwise.py:462, in Elemwise.make_node(self, *inputs)
    456 def make_node(self, *inputs):
    457     """
    458     If the inputs have different number of dimensions, their shape
    459     is left-completed to the greatest number of dimensions with 1s
    460     using DimShuffle.
    461     """
--> 462     inputs = [as_tensor_variable(i) for i in inputs]
    463     out_dtypes, out_broadcastables, inputs = self.get_output_info(
    464         DimShuffle, *inputs
    465     )
    466     outputs = [
    467         TensorType(dtype=dtype, shape=broadcastable)()
    468         for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
    469     ]

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/elemwise.py:462, in <listcomp>(.0)
    456 def make_node(self, *inputs):
    457     """
    458     If the inputs have different number of dimensions, their shape
    459     is left-completed to the greatest number of dimensions with 1s
    460     using DimShuffle.
    461     """
--> 462     inputs = [as_tensor_variable(i) for i in inputs]
    463     out_dtypes, out_broadcastables, inputs = self.get_output_info(
    464         DimShuffle, *inputs
    465     )
    466     outputs = [
    467         TensorType(dtype=dtype, shape=broadcastable)()
    468         for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
    469     ]

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/__init__.py:42, in as_tensor_variable(x, name, ndim, **kwargs)
     10 def as_tensor_variable(
     11     x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
     12 ) -> "TensorVariable":
     13     """Convert `x` into an equivalent `TensorVariable`.
     14 
     15     This function can be used to turn ndarrays, numbers, `Scalar` instances,
   (...)
     40 
     41     """
---> 42     return _as_tensor_variable(x, name, ndim, **kwargs)

File ~/miniconda3/envs/notebooks/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/notebooks/lib/python3.10/site-packages/aesara/tensor/basic.py:176, in _as_tensor_bool(x, name, ndim, **kwargs)
    174 @_as_tensor_variable.register(bool)
    175 def _as_tensor_bool(x, name, ndim, **kwargs):
--> 176     raise TypeError(
    177         "Cannot cast True or False as a tensor variable. Please use "
    178         "np.array(True) or np.array(False) if you need these constants. "
    179         "This error might be caused by using the == operator on "
    180         "Variables. v == w does not do what you think it does, "
    181         "use aesara.tensor.eq(v, w) instead."
    182     )

TypeError: Cannot cast True or False as a tensor variable. Please use np.array(True) or np.array(False) if you need these constants. This error might be caused by using the == operator on Variables. v == w does not do what you think it does, use aesara.tensor.eq(v, w) instead.

Are you running an old version of pymc/aesara by any chance? That sounds like a bug that was solved some time ago.
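
As an aside, the hint at the end of that error message, "use aesara.tensor.eq(v, w) instead", refers to aesara's symbolic elementwise comparison. A minimal sketch of the difference, with made-up variable names:

import aesara.tensor as at

v = at.vector("v")
w = at.vector("w")

# Python's == on two symbolic variables is a plain identity check and
# returns a Python bool; at.eq builds a symbolic elementwise comparison.
is_equal = at.eq(v, w)

Note that in your case the comparison happens inside aesara's own Scan.L_op (the scan/op.py frame in the traceback), so it is a library bug rather than something to fix in the model code above.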

I have tested locally and your snippet works for me with pymc 4.1.3
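
If you want to double-check what you have installed, both packages expose a __version__ attribute (a minimal sketch):

import aesara
import pymc as pm

print(pm.__version__)      # the snippet works for me with 4.1.3
print(aesara.__version__)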


Seems likely to be the reason; let me update.


It is working after upgrading. Thanks @ricardoV94!
