Batch matrix multiplication

Is there a way to do a batch matrix multiplication in PyMC? The prototype operation in NumPy is:

import numpy as np

B, M, N, P = 32, 100, 5, 2
w1_np = np.random.randn(B, M, N)
w2_np = np.random.randn(B, N, P)
data = w1_np.dot(w2_np) # shape: (B, M, P)

where I want to carry out the B matrix multiplications (M, N) × (N, P) → (M, P) in one shot and store the results in
a 3D array of shape (B, M, P). My tentative PyMC code is:

import pymc as pm

with pm.Model() as bmm:

    w1 = pm.Normal('w1', 0, sigma=1, shape=(B, M, N))
    w2 = pm.Normal('w2', 0, sigma=1, shape=(B, N, P))

    #ym = pm.math.dot(w1, w2)
    ym = w1 @ w2  # intended batched product, shape (B, M, P)
    out = pm.Normal('out', mu=ym)  # add observed=data to condition on the data
    prior_data = pm.sample_prior_predictive()

But it does not work, while the version with ym = pm.math.dot(w1, w2) returns a result with the wrong dimensions.
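To pin down the intended semantics, here is a minimal NumPy sketch of the same computation written as an explicit Python loop over the batch axis. It is only a reference for checking results, not an efficient implementation; the `default_rng` generator and seed are illustrative choices in place of `np.random.randn`.

```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# Reference implementation: one ordinary (M, N) x (N, P) matrix
# product per batch index b, written into slot b of a (B, M, P) array.
data_loop = np.empty((B, M, P))
for b in range(B):
    data_loop[b] = w1_np[b].dot(w2_np[b])  # 2-D dot is a plain matrix product

print(data_loop.shape)  # (32, 100, 2)
```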

Thanks,
Marco

This is not what happens.

w1_np.dot(w2_np).shape == (B, M, B, P)
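The (B, M, B, P) shape comes from how `np.dot` treats N-D arrays: per the NumPy documentation, it sum-products the last axis of the first argument with the second-to-last axis of the second, so every batch of `w1_np` is multiplied with every batch of `w2_np`. A small sketch (sizes reduced only for readability) showing that the wanted result sits on the "block diagonal" of that output:

```python
import numpy as np

B, M, N, P = 4, 3, 5, 2  # small sizes, only to keep the example readable
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# For N-D inputs, np.dot sum-products the last axis of w1_np with the
# second-to-last axis of w2_np, pairing every batch i with every batch k:
# full[i, :, k, :] == w1_np[i] @ w2_np[k]
full = w1_np.dot(w2_np)
print(full.shape)  # (4, 3, 4, 2)

# The batched product only needs the blocks where i == k.
diag = np.stack([full[b, :, b, :] for b in range(B)])
print(diag.shape)  # (4, 3, 2)
```

So `dot` computes B times as many matrix products as needed, and all but the i == k blocks are discarded.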

What you want is matmul:

np.matmul(w1_np, w2_np).shape == (B, M, P)

This is what you get with the @ operator in NumPy, and also in PyMC (but only in the latest version; earlier versions had a bug).
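For reference, `np.matmul` (which the @ operator calls for NumPy arrays) treats all but the last two axes as a stack of matrices. A quick NumPy check, using `einsum` only as an independent cross-check (it is not part of the original suggestion):

```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# matmul treats the leading axis as a stack of matrices and multiplies
# the last two axes: (B, M, N) x (B, N, P) -> (B, M, P).
via_matmul = np.matmul(w1_np, w2_np)
# For NumPy arrays, the @ operator is the same operation.
via_at = w1_np @ w2_np
# Independent cross-check: sum over n while keeping b aligned.
via_einsum = np.einsum("bmn,bnp->bmp", w1_np, w2_np)

print(via_matmul.shape)  # (32, 100, 2)
```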

To be safe, you can use pytensor.tensor.matmul(w1, w2).


Thanks Ricardo,

After updating to PyMC 5.9, the @ operator works. Previously I had issues with 5.7.2.


Right, the w1_np.dot(w2_np) line was a leftover from previous attempts; I meant @ there.