The Cholesky factorization of a diagonal matrix is just its element-wise square root, so you might get the speedup in the diagonal model by doing `pm.MvNormal.dist(mu=centroids[i], chol=pt.sqrt(covs[i]))` (assuming `covs[i]` is the full diagonal covariance matrix, not just a vector of variances).
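If you want to convince yourself of the identity first, here is a minimal pure-Python sketch. The `cholesky` helper and the `variances` values are made up for illustration and are not part of PyMC. It factorizes a small diagonal matrix the textbook way and compares the result with the element-wise square root.

```python
import math

def cholesky(a):
    """Return lower-triangular L with L L^T = a, for a small
    symmetric positive-definite matrix given as a list of lists."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Sum of products of the entries already computed in rows i and j
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

# Hypothetical variances on the diagonal of a 3x3 covariance matrix
variances = [4.0, 9.0, 0.25]
cov = [[variances[i] if i == j else 0.0 for j in range(3)] for i in range(3)]

chol = cholesky(cov)
elemwise_sqrt = [[math.sqrt(x) for x in row] for row in cov]

print(chol == elemwise_sqrt)
```

The off-diagonal zeros stay zero under the square root, which is why the shortcut works on a full diagonal matrix. If `covs[i]` were instead a vector of variances, you would presumably need to build the matrix first, e.g. something like `pt.diag(pt.sqrt(covs[i]))`, since `chol` expects a matrix.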