Is there a way to do a batch matrix multiplication in PyMC? The prototype operation in NumPy is:

```
import numpy as np

B, M, N, P = 32, 100, 5, 2
w1_np = np.random.randn(B, M, N)
w2_np = np.random.randn(B, N, P)
data = w1_np @ w2_np  # batched matmul, shape: (B, M, P); w1_np.dot(w2_np) would give (B, M, B, P)
```
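A quick NumPy check of the shapes involved may help pin down what "batch matrix multiplication" means here. It relies on NumPy only; the generator `np.random.default_rng` and every name other than `w1_np`/`w2_np` are just for illustration. It shows that, for 3D arrays, `np.dot` and `np.matmul` (the `@` operator) are not interchangeable: `matmul` and `einsum` give the (B, M, P) result, while `np.dot` pairs every batch with every other batch, and the wanted result sits on its "diagonal".

```python
import numpy as np

B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(0)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

# np.dot on N-D arrays contracts the last axis of w1_np with the
# second-to-last axis of w2_np, pairing every batch of w1_np with
# every batch of w2_np
dot_out = np.dot(w1_np, w2_np)
print(dot_out.shape)  # (32, 100, 32, 2)

# np.matmul (the @ operator) treats the leading axis as a batch axis
matmul_out = np.matmul(w1_np, w2_np)
print(matmul_out.shape)  # (32, 100, 2)

# The same batched product written with einsum
einsum_out = np.einsum("bmn,bnp->bmp", w1_np, w2_np)

# Reference: B separate (M, N) @ (N, P) products stacked along axis 0
loop_out = np.stack([w1_np[b] @ w2_np[b] for b in range(B)])

# The batched result is the part of dot_out where both batch indices match
idx = np.arange(B)
diag = dot_out[idx, :, idx, :]  # shape (32, 100, 2)

print(np.allclose(matmul_out, einsum_out))  # True
print(np.allclose(matmul_out, loop_out))  # True
print(np.allclose(diag, matmul_out))  # True
```

Whether `pm.math.dot` and `@` on PyMC variables follow the same `dot` versus `matmul` split as NumPy may depend on the installed PyMC/PyTensor version, so that is an assumption worth checking rather than a given.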

where I want to carry out the B matrix multiplications (M, N) × (N, P) → (M, P) in one shot and store the results in a 3D array of shape (B, M, P). My tentative PyMC code is:

```
import pymc as pm

# B, M, N, P as defined in the NumPy snippet above
with pm.Model() as bmm:
    w1 = pm.Normal('w1', 0, sigma=1, shape=(B, M, N))
    w2 = pm.Normal('w2', 0, sigma=1, shape=(B, N, P))
    # ym = pm.math.dot(w1, w2)
    ym = w1 @ w2  # intended shape: (B, M, P)
    out = pm.Normal('out', mu=ym)  # , observed=data)
    prior_data = pm.sample_prior_predictive()
```

but it does not work, while the version with `ym = pm.math.dot(w1, w2)` returns a result whose dimensions differ from the expected (B, M, P).
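One workaround that avoids depending on how a particular `dot` or `@` treats 3D inputs is to write the batched product with broadcasting and a sum over the shared axis N. The sketch below checks the idea in NumPy only; `batched_matmul_broadcast` is a hypothetical helper name, and the assumption (not verified here) is that the same expression, built from `None` indexing, elementwise multiplication and `.sum(axis=...)`, can be written on PyMC random variables as well, since those are tensor objects supporting these operations.

```python
import numpy as np


def batched_matmul_broadcast(a, b):
    """Hypothetical helper: batched (M, N) @ (N, P) -> (M, P) via broadcasting.

    a has shape (B, M, N), b has shape (B, N, P); the result has shape (B, M, P).
    The intermediate product has shape (B, M, N, P), so memory grows with N.
    """
    # a[:, :, :, None] -> (B, M, N, 1); b[:, None, :, :] -> (B, 1, N, P)
    prod = a[:, :, :, None] * b[:, None, :, :]  # broadcasts to (B, M, N, P)
    return prod.sum(axis=2)  # sum over the shared N axis -> (B, M, P)


B, M, N, P = 32, 100, 5, 2
rng = np.random.default_rng(1)
w1_np = rng.standard_normal((B, M, N))
w2_np = rng.standard_normal((B, N, P))

out = batched_matmul_broadcast(w1_np, w2_np)
print(out.shape)  # (32, 100, 2)
print(np.allclose(out, w1_np @ w2_np))  # True
```

With the sizes in the question the intermediate array has only 32 × 100 × 5 × 2 = 32,000 entries, so the extra memory is negligible; for large N a dedicated batched matmul would be preferable.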

Thanks,

Marco