I’m using pytensor.scan to set each value to 0 if any of the three preceding (already masked) values in its column are non-zero.
Here’s the function and its output:
import pytensor as pt
import numpy as np
print("pytensor version", pt.__version__)
np.random.seed(42)
m, n = 10, 12
arr = np.random.choice([0, 1], p=[0.75, 0.25], size=m * n).reshape(m, n)
print("\ninput:")
print(arr)
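def mask_prev_three(arr: pt.tensor.TensorLike, n: int) -> pt.tensor.TensorVariable:
    # Output taps: each step also receives out[t-3], out[t-2], out[t-1]
    taps = -3, -2, -1
    # Three rows of zeros stand in for the outputs before the first step
    initial = pt.tensor.zeros((3, n), "int8")
    masked, _ = pt.scan(
        # Keep the current row i0 unless a previous output in that column is non-zero
        lambda i0, im3, im2, im1: pt.tensor.switch(im3 | im2 | im1, 0, i0),
        sequences=arr,
        outputs_info=dict(taps=taps, initial=initial),
    )
    return masked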
masked = mask_prev_three(pt.tensor.as_tensor(arr, dtype="int8"), n).eval()
print("\noutput:")
print(masked)
pytensor version 2.10.1
input:
[[0 1 0 0 0 0 0 1 0 0 0 1]
[1 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 1 1 1 1 0 1 0 0 0 0]
[0 0 1 0 0 0 0 1 0 1 1 0]
[0 1 0 0 1 0 0 0 1 0 0 0]
[0 0 0 0 1 0 0 0 1 0 1 0]
[0 0 0 0 0 0 0 0 1 0 0 1]
[0 0 0 0 1 1 0 1 1 0 1 0]]
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: save_mem_new_scan
ERROR (pytensor.graph.rewriting.basic): node: for{cpu,scan_fn}(TensorConstant{10}, TensorConstant{[[0 1 0 0 .. 1 0 1 0]]}, IncSubtensor{Set;:int64:}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1918, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1078, in transform
    return self.fn(fgraph, node)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/scan/rewriting.py", line 1431, in save_mem_new_scan
    nw_input = expand_empty(_nw_input, tmp_idx)
  File "/Users/pattinson/.virtualenvs/mfsera-env/lib/python3.10/site-packages/pytensor/scan/utils.py", line 234, in expand_empty
    new_shape = [size + shapes[0]] + shapes[1:]
IndexError: list index out of range
output:
[[0 1 0 0 0 0 0 1 0 0 0 1]
[1 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 1 1 1 1 0 1 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 1 1 0 1 0 0 0 0]]
The output is correct, but the errors are a bit worrying; I suspect they might be affecting sampling.
I’ve experimented with the arguments passed to scan, and I’m fairly confident they are correct, since the function produces the desired result.
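One way to confirm the scan result independently of the rewrite errors is to compare it with a plain NumPy loop. The sketch below is a hypothetical reference (the name mask_prev_three_np and the window parameter are my own, not part of PyTensor). It assumes, as the output taps imply, that the mask looks at the three previous output rows rather than the raw input rows, and that rows before the start count as zeros, matching the all-zero initial buffer.

```python
import numpy as np


def mask_prev_three_np(arr: np.ndarray, window: int = 3) -> np.ndarray:
    """Hypothetical NumPy reference for the scan recurrence.

    Row t of the result equals row t of ``arr``, except that an entry is
    set to 0 when any of the previous ``window`` *result* rows in the same
    column is non-zero. Rows before the start are treated as zeros,
    mirroring the all-zero ``initial`` buffer given to scan.
    """
    out = np.zeros_like(arr)
    for t in range(arr.shape[0]):
        # Already-masked rows t-window .. t-1 (fewer near the top).
        prev = out[max(0, t - window):t]
        # True for each column with a non-zero value in that window;
        # for t == 0 the slice is empty and .any() yields all False.
        blocked = prev.any(axis=0)
        out[t] = np.where(blocked, 0, arr[t])
    return out


# Usage with the variables from the script above:
# np.array_equal(mask_prev_three_np(arr), masked)
```

Because the loop writes into out and then reads from out, it reproduces the recurrence on outputs rather than a simple rolling window over the input.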
Any help much appreciated
Thanks
David