How to compare shapes in PyTensor?

A PyTensor graph currently can't evaluate a straightforward shape comparison inside a Python if statement (the code below raises an error). What is the workaround?

Code

import pytensor.tensor as pt

t = pt.as_tensor_variable([10])
if t.shape[0] > 10:
    print("Hello world!")
else:
    print("Hi!")

Output

TypeError: Variables do not support boolean operations.

For symbolic control flow, use pt.switch. The comparison itself is fine; the problem is Python's if, which needs a concrete boolean, whereas x.shape[0] > 10 is a symbolic variable:

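import numpy as np
import pytensor.tensor as pt

x = pt.dmatrix('x')
# pt.switch selects 1 or 0 symbolically, based on the number of rows at runtime
y = pt.switch(x.shape[0] > 10, 1, 0)
y.eval({x: np.zeros((13, 10))})  # evaluates to 1, since 13 > 10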