Batch matrix multiplication

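That is not what happens with dot. On stacked (3D) inputs it pairs every matrix in w1 with every matrix in w2, so you get two batch dimensions: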

w1_np.dot(w2_np).shape == (B, M, B, P)

What you want is matmul:

np.matmul(w1_np, w2_np).shape == (B, M, P)

That is also what the @ operator gives you in NumPy and in PyMC, but in PyMC only in the latest version; earlier versions had a bug here.
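Here is a small NumPy-only sketch of the difference. The sizes B, M, N, P and the random inputs are made up for illustration; I'm assuming w1_np has shape (B, M, N) and w2_np has shape (B, N, P).

```python
import numpy as np

# Hypothetical sizes, chosen only so every axis is easy to tell apart
B, M, N, P = 4, 2, 3, 5
rng = np.random.default_rng(0)
w1_np = rng.normal(size=(B, M, N))
w2_np = rng.normal(size=(B, N, P))

# dot contracts the last axis of w1 with the second-to-last axis of w2
# and keeps all remaining axes, so both batch axes survive
dot_out = w1_np.dot(w2_np)

# matmul (and @) treat the leading axis as a shared batch axis
matmul_out = np.matmul(w1_np, w2_np)
at_out = w1_np @ w2_np

print(dot_out.shape)     # (B, M, B, P)
print(matmul_out.shape)  # (B, M, P)

# The batched product is the "diagonal" of the dot result over its two batch axes
idx = np.arange(B)
diag = dot_out[idx, :, idx, :]
print(np.allclose(diag, matmul_out))
```

So dot does compute the products you want, but it also computes all the cross-batch products you don't, which wastes memory and time as B grows.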

To be safe, you can call pytensor.tensor.matmul(w1, w2) explicitly.
