Aesara outer product with more than two inputs

Maybe not the best possible solution, but what I came up with was to store the lengths of the arrays in Y, pad them all into a single matrix, and feed that into a scan. Something like:

import numpy as np

def pad_Y(Y):
    '''Transform a ragged list Y of n arrays into a matrix of shape (n, max(lengths)), zero-padding each row'''
    lengths = [len(y) for y in Y]
    max_len = max(lengths)

    # Right-pad each array with zeros up to the longest length
    Y_pad = [np.r_[y, np.zeros(max_len - len(y))] for y in Y]

    return np.stack(Y_pad)
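As a quick sanity check in plain NumPy (with `pad_Y` restated compactly so this runs on its own), the original rows can be recovered from the padded matrix by slicing with the stored lengths, which is exactly what `outer_step` does with `l_shape`/`r_shape`:

```python
import numpy as np

def pad_Y(Y):
    '''Zero-pad a ragged list of 1-D arrays into an (n, max_len) matrix.'''
    max_len = max(len(y) for y in Y)
    return np.stack([np.r_[y, np.zeros(max_len - len(y))] for y in Y])

Y = [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])]
lengths = [len(y) for y in Y]
Y_pad = pad_Y(Y)

assert Y_pad.shape == (2, 3)
# Slicing each padded row down to its true length recovers the input
assert all(np.allclose(Y_pad[i, :lengths[i]], Y[i]) for i in range(len(Y)))
```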

def outer_step(l, r, l_x, r_x, l_shape, r_shape):
    '''Outer product between l and r via broadcast multiplication'''
    # Slice away the padding, then scale by the x values
    l = l[:l_shape] * l_x
    r = r[:r_shape] * r_x

    # Append one new axis per dimension of r, so that multiplying
    # broadcasts into the outer product
    outer_dims = tuple([slice(None)] * l.ndim + [None] * r.ndim)
    return l[outer_dims] * r
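The broadcasting trick in `outer_step` is easiest to verify in plain NumPy, where the same index tuple reproduces `np.outer` for a pair of vectors:

```python
import numpy as np

l = np.array([1.0, 2.0])
r = np.array([3.0, 4.0, 5.0])

# For 1-D inputs this index tuple is (slice(None), None), i.e. l[:, None],
# so the multiplication broadcasts (2, 1) * (3,) -> (2, 3)
outer_dims = tuple([slice(None)] * l.ndim + [None] * r.ndim)
outer = l[outer_dims] * r

assert np.allclose(outer, np.outer(l, r))
assert np.allclose(outer, np.einsum("i,j->ij", l, r))
```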

import aesara
import aesara.tensor as at

Y = at.matrix()
Y_lens = at.ivector()
x = at.vector()

output, updates = aesara.scan(outer_step,
                              sequences=(dict(input=Y, taps=[0, -1]),
                                         dict(input=x, taps=[0, -1]),
                                         dict(input=Y_lens, taps=[0, -1])))

It works for the D = 2 case (though the output is transposed, because scan feeds the matrix row-wise), but it fails for higher D because I wasn’t sure how you wanted to define the outer product of a list of vectors (the example numpy code isn’t valid as written; it only takes two inputs). I assume the most general case would be something like np.einsum("i, j, k, ... -> ijk...", a, b, c, d, ...)?

I couldn’t quite work out how to get an accumulator into the scan, since the dimensionality of the tensor would grow at each iteration, and scan doesn’t allow that. Maybe use Y_lens to allocate the final output object (at.zeros(Y_lens)?), then at.set_subtensor inside the scan to fill in the dimensions as you go?
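For what it's worth, the generalized outer product itself is easy to express in plain NumPy without any accumulator growing inside a loop body, by folding a binary outer product over the list of vectors (this sidesteps the growing-dimension problem at the Python level, though it doesn't directly translate into a scan):

```python
import numpy as np
from functools import reduce

vecs = [np.arange(2.0) + 1, np.arange(3.0) + 1, np.arange(4.0) + 1]

# Fold the binary outer product over the list: shapes go
# (2,) -> (2, 3) -> (2, 3, 4)
out = reduce(np.multiply.outer, vecs)

# Matches the einsum formulation for D = 3
expected = np.einsum("i,j,k->ijk", *vecs)
assert np.allclose(out, expected)
```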

If you know D ahead of time you could avoid the scan entirely by declaring D at.vectors and looping over them when building the graph, but since D is unknown I think you might have to scan?
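If D is known when the graph is built, the einsum subscript string can be assembled on the Python side. A NumPy sketch of that idea (the same string-building would apply to whatever symbolic backend you loop the vectors through):

```python
import numpy as np
import string

def outer_product(*vecs):
    '''Outer product of any number of 1-D arrays via a
    dynamically built einsum subscript, e.g. "i,j,k->ijk".'''
    letters = string.ascii_lowercase[:len(vecs)]
    subscripts = ",".join(letters) + "->" + letters
    return np.einsum(subscripts, *vecs)

a, b, c = np.ones(2), np.arange(3.0), np.arange(4.0)
result = outer_product(a, b, c)
assert result.shape == (2, 3, 4)
```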
