Aesara outer product more than two inputs

I’m trying to take the outer product over a list of vectors stored in a ragged array. The ragged array is a list of length D of Aesara 1d tensors of different lengths. It and its vectors are fixed: it has no unknown parameters, and it’s passed in at the start. I’ll need to take gradients with respect to a different vector x.

Here is a simple 2D example in numpy:

import numpy as np

Y = [
    np.linspace(-1, 1, 3),
    np.linspace(2, 3, 4),
]
x = np.array([4, 5])

result = np.outer(
    x[0] * Y[0],
    x[1] * Y[1],
)

What’s the best way to do this in Aesara for arbitrary D (D will be small though, usually less than 6)? For instance, with D = 6 I’d want the 6-way outer product of

    x[0] * Y[0],
    x[1] * Y[1],
    ...
    x[5] * Y[5],

but without the explicit indexing of course.
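For concreteness, the D-way product I mean can be written in plain NumPy by folding `np.multiply.outer` over the scaled vectors (`outer_all` is just a name I made up for this sketch):

```python
import functools
import numpy as np

def outer_all(vectors):
    """D-way outer product: result[i, j, k, ...] = v0[i] * v1[j] * v2[k] * ..."""
    return functools.reduce(np.multiply.outer, vectors)

Y = [np.linspace(-1, 1, 3), np.linspace(2, 3, 4), np.linspace(0, 1, 5)]
x = np.array([4.0, 5.0, 6.0])

result = outer_all([x_d * y_d for x_d, y_d in zip(x, Y)])
print(result.shape)  # (3, 4, 5)
```

This agrees with the einsum formulation for any fixed D, e.g. `np.einsum("i,j,k->ijk", ...)` for D = 3.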

Maybe not the best possible solution, but what I came up with was to store the lengths of Y, pad everything into a matrix, and feed it into a scan. Something like:

def pad_Y(Y):
    '''Transform ragged list Y of length n into a matrix of shape n x max(len(Y))'''
    lengths = [len(y) for y in Y]
    max_len = max(lengths)

    Y_pad = [np.r_[y, np.full(max_len - len(y), 0)] for y in Y]

    Y_out = np.stack(Y_pad)
    return Y_out
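As a quick sanity check of the padding step (re-declaring a condensed version of the helper so the snippet runs on its own):

```python
import numpy as np

def pad_Y(Y):
    """Pad a ragged list of 1d arrays into an (n, max_len) matrix with zeros."""
    max_len = max(len(y) for y in Y)
    return np.stack([np.r_[y, np.zeros(max_len - len(y))] for y in Y])

Y_pad = pad_Y([np.linspace(-1, 1, 3), np.linspace(2, 3, 4)])
print(Y_pad.shape)  # (2, 4)
```

The shorter vector’s row is zero-padded on the right, so the true lengths have to be carried alongside.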

def outer_step(l, r, l_x, r_x, l_shape, r_shape):
    '''Outer product between l and r using broadcast multiplication'''
    l = l[:l_shape] * l_x
    r = r[:r_shape] * r_x
    outer_dims = tuple([slice(None)]*l.ndim + [None]*r.ndim)
    return (l[outer_dims] * r)

import aesara
import aesara.tensor as at

Y = at.matrix()
Y_lens = at.ivector()
x = at.vector()

output, updates = aesara.scan(
    outer_step,
    sequences=(
        dict(input=Y, taps=[0, -1]),
        dict(input=x, taps=[0, -1]),
        dict(input=Y_lens, taps=[0, -1]),
    ),
)

It works for the D = 2 case (though the output is transposed, because scan feeds the sequences row-wise), but it fails for higher D because I wasn’t sure how to define the outer product between a list of vectors (the example numpy code isn’t valid for this; np.outer only takes two inputs). I thought the most general case would be something like np.einsum("i, j, k, ... -> ijk...", a, b, c, d, ...). I couldn’t quite work out how to get an accumulator into the scan, since the dimension of the tensor would grow at each iteration, and scan doesn’t like that. Maybe use Y_lens to make the final output object (at.zeros(Y_lens)?) and then at.set_subtensor inside the scan to fill in the dimensions as you go?

If you knew D ahead of time, you could avoid the scan by declaring D at.vectors and looping over them in Python, but since D is unknown I think you might have to scan?