Is there a way to do a batch matrix multiplication in PyMC? The prototype operation in NumPy is:
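```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
w1_np = np.random.randn(B, M, N)
w2_np = np.random.randn(B, N, P)
# Batched matmul; np.dot(w1_np, w2_np) would instead give shape (B, M, B, P)
data = w1_np @ w2_np  # shape: (B, M, P)
```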
Here I want to carry out the B matrix multiplications (M, N) × (N, P) → (M, P) in one shot and store the results in a 3D array of shape (B, M, P). My tentative PyMC code is:
```python
import pymc as pm

with pm.Model() as bmm:
    w1 = pm.Normal('w1', 0, sigma=1, shape=(B, M, N))
    w2 = pm.Normal('w2', 0, sigma=1, shape=(B, N, P))
    # ym = pm.math.dot(w1, w2)
    ym = w1 @ w2
    out = pm.Normal('out', mu=ym)  # , observed=data)
    prior_data = pm.sample_prior_predictive()
```
but it does not work, while the version using `ym = pm.math.dot(w1, w2)` instead returns a result with different dimensions.
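A NumPy-only sketch of why the shapes differ (the seed and the names `dot_out`, `matmul_out`, `einsum_out`, `bcast_out` are illustrative): on 3-D arrays `np.dot` pairs every batch of the first array with every batch of the second, giving shape (B, M, B, P), while `@`/`np.matmul` and `np.einsum` treat the leading axis as a batch axis. If `pm.math.dot` follows the `np.dot` convention, as it appears to, that would explain the extra dimension. The last formulation uses only broadcasting and a sum, so it does not rely on batched-matmul support in the tensor library.

```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# np.dot on N-D arrays: sum over the last axis of w1_np and the
# second-to-last axis of w2_np, for every pair of batches -> (B, M, B, P)
dot_out = np.dot(w1_np, w2_np)

# @ / np.matmul: leading axis is a batch axis -> (B, M, P)
matmul_out = w1_np @ w2_np

# The same batched product written with einsum
einsum_out = np.einsum("bmn,bnp->bmp", w1_np, w2_np)

# Broadcasting only: (B, M, N, 1) * (B, 1, N, P) -> (B, M, N, P),
# then sum over the shared N axis -> (B, M, P)
bcast_out = (w1_np[:, :, :, None] * w2_np[:, None, :, :]).sum(axis=2)

# The "diagonal" of the np.dot result (batch b with batch b) is the batched product
diag = np.stack([dot_out[b, :, b, :] for b in range(B)])

print(dot_out.shape, matmul_out.shape)
```

Inside the model, the broadcast-and-sum line should carry over with `w1` and `w2` in place of the NumPy arrays, since PyTensor tensors support `None` indexing and `.sum(axis=...)`; `pytensor.tensor.batched_dot(w1, w2)` is another candidate for this (B, M, N) × (B, N, P) case. Both are assumptions about a PyTensor-based PyMC install rather than tested code.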