Pytensor.scan IndexError

I’m using pytensor.scan to set a value to 0 if any of the three preceding (already masked) values in its column is non-zero.

Here’s the function, and its output:

import pytensor as pt
import numpy as np

print("pytensor version", pt.__version__)

np.random.seed(42)
m, n = 10, 12
arr = np.random.choice([0, 1], p=[0.75, 0.25], size=m * n).reshape(m, n)

print("\ninput:")
print(arr)

def mask_prev_three(arr: pt.tensor.TensorLike, n: int) -> pt.tensor.TensorVariable:
    taps = -3, -2, -1
    initial = pt.tensor.zeros((3, n), "int8")
    masked, _ = pt.scan(
        lambda i0, im3, im2, im1: pt.tensor.switch(im3 | im2 | im1, 0, i0),
        sequences=arr,
        outputs_info=dict(taps=taps, initial=initial),
    )
    return masked

masked = mask_prev_three(pt.tensor.as_tensor(arr, dtype="int8"), n).eval()

print("\noutput:")
print(masked)
pytensor version 2.10.1

input:
[[0 1 0 0 0 0 0 1 0 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 1 1 1]
 [0 0 0 0 0 0 0 1 0 0 0 0]
 [0 0 1 1 1 1 0 1 0 0 0 0]
 [0 0 1 0 0 0 0 1 0 1 1 0]
 [0 1 0 0 1 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0 1 0 1 0]
 [0 0 0 0 0 0 0 0 1 0 0 1]
 [0 0 0 0 1 1 0 1 1 0 1 0]]

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: save_mem_new_scan
ERROR (pytensor.graph.rewriting.basic): node: for{cpu,scan_fn}(TensorConstant{10}, TensorConstant{[[0 1 0 0 .. 1 0 1 0]]}, IncSubtensor{Set;:int64:}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1918, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1078, in transform
    return self.fn(fgraph, node)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/scan/rewriting.py", line 1431, in save_mem_new_scan
    nw_input = expand_empty(_nw_input, tmp_idx)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/scan/utils.py", line 234, in expand_empty
    new_shape = [size + shapes[0]] + shapes[1:]
IndexError: list index out of range


output:
[[0 1 0 0 0 0 0 1 0 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 1 1 1 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 1 1 0 1 0 0 0 0]]

The output is as it should be, but the errors that are generated are a bit worrying; I think this might be impacting sampling.

I’ve experimented with the variables being passed to scan and I’m pretty sure everything is correct for the desired functionality (which is occurring).
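
One way to double-check the result independently of scan is a plain NumPy loop over the rows. This is only a minimal sketch of the same recurrence: `mask_prev_three_reference` is a hypothetical helper name (not part of PyTensor), and the arrays are the first six rows of the input and output printed above.

```python
import numpy as np


def mask_prev_three_reference(arr: np.ndarray) -> np.ndarray:
    """Zero a value if any of the three preceding *output* rows is non-zero in its column."""
    out = np.zeros_like(arr)
    for i in range(arr.shape[0]):
        # Up to three previous output rows; empty for the first row.
        prev = out[max(i - 3, 0):i]
        blocked = prev.any(axis=0)
        out[i] = np.where(blocked, 0, arr[i])
    return out


# First six rows of the input printed above.
first_rows = np.array(
    [
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0],
    ],
    dtype="int8",
)

# First six rows of the scan output printed above.
expected_rows = np.array(
    [
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    dtype="int8",
)

matches = np.array_equal(mask_prev_three_reference(first_rows), expected_rows)
print("reference matches scan output:", matches)
```

The mask looks at the previous three rows of the output rather than the input, which mirrors how the taps in outputs_info feed earlier results back into the scan step.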

Any help much appreciated
Thanks

David

The error shouldn’t affect sampling. It happens when PyTensor tries to apply a rewrite (save_mem_new_scan) that is not important here. The rewrite fails (it shouldn’t error like that, of course), but since it doesn’t matter here, it’s fine in the end.

What version of PyTensor are you using? I remember fixing something related to zeros and this rewrite. Can you try passing a constant array of 3 zeros, pt.as_tensor(np.zeros(3, dtype="int8")), and see if the error goes away?

Passing np.zeros(3, dtype="int8") generates the same error.

BTW I checked whether aesara had the same issue, and it did. See issue #1499 (aesara.scan IndexError) in aesara-devs/aesara on GitHub.

Shall I open an issue on the pytensor repo?

Yes, much appreciated

GitHub issue is here
