How to use MatrixNormal for modeling several matrices?

I would like to build the likelihood on several matrices simultaneously, i.e., my observed data (tensor) is of shape (Z, M, N)=(885, 50, 6).

K = 5
T = 50
P = 6
N = data_x.shape[0]

with pm.Model(coords={"component": np.arange(K), "obs_id": np.arange(N)}) as model:
    # DP prior
    pi = pm.StickBreakingWeights("pi", alpha=1, K=K-1)
    
    # Priors for the mean matrices for each component
    M = pm.Normal('M', mu=0, sigma=5, shape=(K, T, P))

    # Priors for column covariance matrices
    V_chol = [pm.LKJCholeskyCov('V_chol_%d' % i, eta=2, n=P, sd_dist=pm.HalfCauchy.dist(2.5))
              for i in range(K)]
    V = [pm.Deterministic('V_%d' % i, pt.dot(chol[0], chol[0].T))
         for i, chol in enumerate(V_chol)]
    
    # Priors for row covariance matrices
    U_chol = [pm.LKJCholeskyCov('U_chol_%d' % i, eta=2, n=T, sd_dist=pm.HalfCauchy.dist(2.5))
              for i in range(K)]
    U = [pm.Deterministic('U_%d' % i, pt.dot(chol[0], chol[0].T))
         for i, chol in enumerate(U_chol)]

    # Mixture components
    components = [pm.MatrixNormal.dist(mu=M[i], rowcov=U[i], colcov=V[i])
                  for i in range(K)]
    
    x_obs = pm.Mixture('x_obs', w=pi, comp_dists=components, observed=data_x, dims="obs_id")
    
    tr = pm.sample(2000, tune=3000, init='jitter+adapt_diag_grad', chains=1, cores=8,
                       discard_tuned_samples=True, return_inferencedata=True)

However, it seems that MatrixNormal doesn’t support a tensor as the input.

---------------------------------------------------------------------------
ShapeError                                Traceback (most recent call last)
Cell In[6], line 29
     25 # Mixture components
     26 components = [pm.MatrixNormal.dist(mu=M[i], rowcov=U[i], colcov=V[i])
     27               for i in range(K)]
---> 29 x_obs = pm.Mixture('x_obs', w=pi, comp_dists=components, observed=data_x, dims="obs_id")
     31 tr = pm.sample(2000, tune=3000, init='jitter+adapt_diag_grad', chains=1, cores=8,
     32                    discard_tuned_samples=True, return_inferencedata=True)

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\distributions\distribution.py:371, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    367         kwargs["shape"] = tuple(observed.shape)
    369 rv_out = cls.dist(*args, **kwargs)
--> 371 rv_out = model.register_rv(
    372     rv_out,
    373     name,
    374     observed,
    375     total_size,
    376     dims=dims,
    377     transform=transform,
    378     initval=initval,
    379 )
    381 # add in pretty-printing support
    382 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\model\core.py:1281, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1274         raise TypeError(
   1275             "Variables that depend on other nodes cannot be used for observed data."
   1276             f"The data variable was: {observed}"
   1277         )
   1279     # `rv_var` is potentially changed by `make_obs_var`,
   1280     # for example into a new graph for imputation of missing data.
-> 1281     rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
   1283 return rv_var

File ~\anaconda3\envs\pymc5\Lib\site-packages\pymc\model\core.py:1312, in Model.make_obs_var(self, rv_var, data, dims, transform, total_size)
   1309 data = convert_observed_data(data).astype(rv_var.dtype)
   1311 if data.ndim != rv_var.ndim:
-> 1312     raise ShapeError(
   1313         "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
   1314     )
   1316 mask = getattr(data, "mask", None)
   1317 if mask is not None:

ShapeError: Dimensionality of data and RV don't match. (actual 3 != expected 2)

And I also found that the example provided by the official document also used only one matrix served as the observed data: pymc.MatrixNormal — PyMC dev documentation

Could anyone please provide any suggestion to solve this problem? Thanks!

Your dims don’t match the shape of your observed data. You should have as many dims as data dimensions.