# How to use MatrixNormal for modeling several matrices?

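I would like to define a likelihood over several matrices simultaneously: my observed data is a 3-D tensor of shape (Z, M, N) = (885, 50, 6), i.e. 885 matrices of size 50 × 6. (In the code below these three sizes are called `N`, `T` and `P`, and `M` denotes the component mean matrices.)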

``````python
K = 5
# T and P are the row and column sizes of each matrix: they are used as the
# last two axes of the mean tensor M and as the sizes of the row (U) and
# column (V) covariance matrices below.
T = 50
P = 6
# Number of observed matrices, taken from the leading axis of the data.
N = data_x.shape[0]

with pm.Model(coords={"component": np.arange(K), "obs_id": np.arange(N)}) as model:
    # DP prior: stick-breaking weights over the mixture components.
    pi = pm.StickBreakingWeights("pi", alpha=1, K=K - 1)

    # Priors for the mean matrices, one (T, P) matrix per component.
    M = pm.Normal("M", mu=0, sigma=5, shape=(K, T, P))

    # Priors for the column covariance matrices (P x P), one per component.
    # Element 0 of what LKJCholeskyCov returns is used as the Cholesky
    # factor L, and the covariance is rebuilt as L @ L.T.
    V_chol = [
        pm.LKJCholeskyCov(
            f"V_chol_{i}", eta=2, n=P, sd_dist=pm.HalfCauchy.dist(2.5)
        )
        for i in range(K)
    ]
    V = [
        pm.Deterministic(f"V_{i}", pt.dot(chol[0], chol[0].T))
        for i, chol in enumerate(V_chol)
    ]

    # Priors for the row covariance matrices (T x T), built the same way.
    U_chol = [
        pm.LKJCholeskyCov(
            f"U_chol_{i}", eta=2, n=T, sd_dist=pm.HalfCauchy.dist(2.5)
        )
        for i in range(K)
    ]
    U = [
        pm.Deterministic(f"U_{i}", pt.dot(chol[0], chol[0].T))
        for i, chol in enumerate(U_chol)
    ]

    # Mixture components: one matrix-normal distribution per component.
    components = [
        pm.MatrixNormal.dist(mu=M[i], rowcov=U[i], colcov=V[i])
        for i in range(K)
    ]

    # NOTE(review): each component is built from a single (T, P) mean matrix,
    # while data_x has a leading observation axis as well. This looks like the
    # source of the "Dimensionality of data and RV don't match" ShapeError
    # raised on this line -- confirm how the observation axis should be
    # passed to the mixture.
    x_obs = pm.Mixture(
        "x_obs", w=pi, comp_dists=components, observed=data_x, dims="obs_id"
    )

``````

However, it seems that MatrixNormal does not accept a 3-D tensor (a stack of matrices) as observed data; the model fails with the following error:

``````
---------------------------------------------------------------------------
ShapeError                                Traceback (most recent call last)
Cell In[6], line 29
25 # Mixture components
26 components = [pm.MatrixNormal.dist(mu=M[i], rowcov=U[i], colcov=V[i])
27               for i in range(K)]
---> 29 x_obs = pm.Mixture('x_obs', w=pi, comp_dists=components, observed=data_x, dims="obs_id")

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\distributions\distribution.py:371, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
367         kwargs["shape"] = tuple(observed.shape)
369 rv_out = cls.dist(*args, **kwargs)
--> 371 rv_out = model.register_rv(
372     rv_out,
373     name,
374     observed,
375     total_size,
376     dims=dims,
377     transform=transform,
378     initval=initval,
379 )
381 # add in pretty-printing support
382 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\model\core.py:1281, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
1274         raise TypeError(
1275             "Variables that depend on other nodes cannot be used for observed data."
1276             f"The data variable was: {observed}"
1277         )
1279     # `rv_var` is potentially changed by `make_obs_var`,
1280     # for example into a new graph for imputation of missing data.
-> 1281     rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
1283 return rv_var

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\model\core.py:1312, in Model.make_obs_var(self, rv_var, data, dims, transform, total_size)
1309 data = convert_observed_data(data).astype(rv_var.dtype)
1311 if data.ndim != rv_var.ndim:
-> 1312     raise ShapeError(
1313         "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
1314     )