Is there a way to do a batched matrix multiplication in PyMC? The prototype operation in NumPy is:
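```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
w1_np = np.random.randn(B, M, N)
w2_np = np.random.randn(B, N, P)
data = w1_np @ w2_np  # batched matmul, shape: (B, M, P)
# note: w1_np.dot(w2_np) would instead give shape (B, M, B, P)
```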
where I want to carry out the B matrix products (M, N) @ (N, P) → (M, P) in one shot and store them in a 3-D array of shape (B, M, P). My tentative PyMC code is:
```python
import pymc as pm

with pm.Model() as bmm:
    w1 = pm.Normal('w1', 0, sigma=1, shape=(B, M, N))
    w2 = pm.Normal('w2', 0, sigma=1, shape=(B, N, P))
    # ym = pm.math.dot(w1, w2)
    ym = w1 @ w2
    out = pm.Normal('out', mu=ym)  # observed=data left out for now
    prior_data = pm.sample_prior_predictive()
```
but `ym = w1 @ w2` does not work, while the version with `ym = pm.math.dot(w1, w2)` returns a result whose dimensions differ from the expected (B, M, P).
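In NumPy terms, the two behaviours are `np.matmul` (the `@` operator), which treats the leading axis as a batch axis, and `np.dot`, which for N-D arrays contracts the last axis of the first argument with the second-to-last axis of the second and so keeps both batch axes. The sketch below is plain NumPy with seeded random inputs (not PyMC). It checks the resulting shapes and confirms that `@`, `np.einsum("bmn,bnp->bmp", ...)` and an explicit loop over the B entries all agree. Any correspondence to how PyMC's `@` and `pm.math.dot` behave is an assumption here, not something this code tests.

```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# Reference: B separate (M, N) @ (N, P) products, one per batch entry.
looped = np.stack([w1_np[b] @ w2_np[b] for b in range(B)])

# Batched matmul: `@` / np.matmul broadcasts over the leading batch axis.
batched = w1_np @ w2_np
# Equivalent einsum spelling of the same contraction over n.
einsummed = np.einsum("bmn,bnp->bmp", w1_np, w2_np)

# np.dot instead pairs every batch entry of w1_np with every entry of w2_np.
dotted = np.dot(w1_np, w2_np)

print(batched.shape)  # (32, 100, 2)
print(dotted.shape)   # (32, 100, 32, 2)
print(np.allclose(looped, batched), np.allclose(looped, einsummed))  # True True
# The wanted products sit where both batch indices agree: dotted[b, :, b, :].
print(np.allclose(np.stack([dotted[b, :, b, :] for b in range(B)]), batched))  # True
```

The einsum form is the most explicit statement of the intended contraction, since every index, including the shared batch index `b`, is named.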
Thanks,
Marco