Help in extending MatrixNormal to 3 Dimensions

I’m wondering what the appropriate way would be to extend MatrixNormal to handle three dimensions. Looking at the code, some of the extension seems straightforward, but other parts (such as the computation of trquaddist) do not.

For the quaddist, would it just be as follows?

        quaddist = self.solve_lower(rowchol_cov, delta)

        quaddist = tt.nlinalg.matrix_dot(quaddist.T, quaddist)
        quaddist = self.solve_lower(colchol_cov, quaddist)
        quaddist = self.solve_upper(colchol_cov.T, quaddist)

        # The following three lines are my addition:
        quaddist = tt.nlinalg.matrix_dot(quaddist.T, quaddist)
        quaddist = self.solve_lower(planechol_cov, quaddist)
        quaddist = self.solve_upper(planechol_cov.T, quaddist)

        trquaddist = tt.nlinalg.trace(quaddist)

Any help here or suggestions in general on this would be appreciated.
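
For anyone who wants to check a candidate three-way quadratic form numerically, here is a minimal NumPy sketch. It assumes the 3D extension means that the flattened (row-major) array has covariance kron(plane_cov, row_cov, col_cov); that is an assumption about the intended model, not something taken from the MatrixNormal source. Under that assumption, the quadratic form is the sum of squares of delta after solving with each Cholesky factor along its own axis, so no further chained matrix products are needed. In the 2D case the same quantity equals the trace that MatrixNormal calls trquaddist. The helper names (solve_along_axis, tensor_quaddist, kron_logdet, random_spd) are hypothetical and not part of PyMC3.

```python
import numpy as np


def solve_along_axis(chol, x, axis):
    """Apply inv(chol) along one axis of x (hypothetical helper)."""
    moved = np.moveaxis(x, axis, 0)
    flat = moved.reshape(moved.shape[0], -1)
    # A general solve is used for simplicity; a triangular solve would be cheaper.
    solved = np.linalg.solve(chol, flat)
    return np.moveaxis(solved.reshape(moved.shape), 0, axis)


def tensor_quaddist(delta, chols):
    """vec(delta)^T inv(kron(covs)) vec(delta), without forming the Kronecker product."""
    z = delta
    for axis, chol in enumerate(chols):
        z = solve_along_axis(chol, z, axis)
    return np.sum(z ** 2)


def kron_logdet(chols):
    """log|kron(covs)| computed from the Cholesky factors of the individual covs."""
    sizes = [c.shape[0] for c in chols]
    total = int(np.prod(sizes))
    return sum(
        2.0 * (total // n) * np.sum(np.log(np.diag(c)))
        for c, n in zip(chols, sizes)
    )


def random_spd(rng, n):
    """Random symmetric positive-definite matrix, used only for the check below."""
    a = rng.standard_normal((n, n))
    return a @ a.T + n * np.eye(n)


rng = np.random.default_rng(0)
covs = [random_spd(rng, n) for n in (2, 3, 4)]  # plane, row, column covariances
chols = [np.linalg.cholesky(c) for c in covs]
delta = rng.standard_normal((2, 3, 4))          # stands in for value - mu

# Dense reference: only feasible for small sizes.
dense = np.kron(np.kron(covs[0], covs[1]), covs[2])
v = delta.ravel()
quad_dense = v @ np.linalg.solve(dense, v)
quad_fast = tensor_quaddist(delta, chols)
print(quad_fast, quad_dense)
```

Comparing against the dense Kronecker product like this is only practical for small sizes, but it gives a quick way to validate any Theano implementation of the same quantity before putting it in logp.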

I see now that KroneckerNormal may be what I was thinking of:
https://docs.pymc.io/api/distributions/multivariate.html#pymc3.distributions.multivariate.KroneckerNormal

Some docs are provided here:
https://docs.pymc.io/notebooks/GP-Kron.html
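
To spell out why a Kronecker-structured normal covers the three-way case, here is the density I have in mind. The symbols are my own notation, not taken from the PyMC3 docs: X is an n_1 × n_2 × n_3 array with mean M, the Σ_k are the covariances along each axis, N = n_1 n_2 n_3, and vec flattens in row-major order so that the factor ordering matches the axes.

```latex
\operatorname{vec}(X) \sim \mathcal{N}\!\left(\operatorname{vec}(M),\; \Sigma_1 \otimes \Sigma_2 \otimes \Sigma_3\right)

\log p(X) = -\frac{1}{2}\left[
    N \log(2\pi)
    + \sum_{k=1}^{3} \frac{N}{n_k} \log\lvert\Sigma_k\rvert
    + \operatorname{vec}(X - M)^{\top}
      \left(\Sigma_1^{-1} \otimes \Sigma_2^{-1} \otimes \Sigma_3^{-1}\right)
      \operatorname{vec}(X - M)
\right]
```

With only two factors this reduces to the matrix normal, where the quadratic term is the trace tr[Σ_2^{-1} (X - M)^T Σ_1^{-1} (X - M)].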

I also found this article on the subject:
https://www.sciencedirect.com/science/article/pii/S0377042712003810