Can PyMC simplify MvNormal with diagonal covariance matrices?

Also, it’s presumably faster in PyMC to just use univariate normals when there is an identity covariance matrix.

We rewrite / simplify / eliminate linalg Ops when the structure of matrices can be inferred (including identity), so ideally it should be the same.

Not just identity but diagonal. We should check whether our rewrites can already simplify the density close to the univariate form.
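For concreteness, this is the equivalence being exploited, checked numerically with scipy rather than PyMC itself: a multivariate normal density with diagonal covariance is exactly the product of independent univariate normal densities.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
d = 4
mu = rng.normal(size=d)
sigma2 = rng.uniform(0.5, 2.0, size=d)  # diagonal variances
x = rng.normal(size=d)

# Full multivariate density with a diagonal covariance matrix...
logp_mv = stats.multivariate_normal(mean=mu, cov=np.diag(sigma2)).logpdf(x)

# ...equals the sum of the corresponding univariate log-densities.
logp_uni = stats.norm(loc=mu, scale=np.sqrt(sigma2)).logpdf(x).sum()
```

So a rewrite that detects the diagonal structure can skip the O(n^3) factorization entirely.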

Neat. We haven’t thought to do that in Stan’s multivariate normal implementation. Some of the data flow analysis in model code is much harder for us even if we had thought about it.

How do you do it in general, and in particular when the covariance matrix is unknown and autodiff is involved? Do you do the quadratic test of all the off-diagonal elements, or is there some structure left in the model you test? Testing the off-diagonals should be faster than factoring the matrix (quadratic rather than cubic), but probably slower than calling the univariate density, depending on how you code the autodiff (at least that'd be true for Stan).

Analysis is all done at compile time, not runtime.

If the covariance is constant, like here, we can do a one-time eval during compilation. Otherwise, if it's built symbolically, like eye() * x or zeros[diag_indices].set(x), possibly composed with elemwise operations that map 0 -> 0, we infer it must be diagonal.
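A toy sketch of that kind of structural inference (a hypothetical mini-IR with tuple expressions, not PyTensor's actual rewriter): diagonality is a static property you can propagate through the graph, because zero-preserving elemwise ops keep the off-diagonal zeros.

```python
import numpy as np

# Hypothetical expression IR: ("eye",), ("const", array), ("mul", a, b), ...
# A real rewriter tracks this as a property on graph nodes; this is a sketch.
ZERO_PRESERVING = {"mul", "neg", "sqrt"}  # elemwise ops with f(0) == 0

def is_diagonal(expr):
    head = expr[0]
    if head == "eye":
        return True
    if head == "const":
        a = expr[1]
        return np.allclose(a, np.diag(np.diag(a)))
    if head in ZERO_PRESERVING:
        if head == "mul":
            # 0 * anything == 0, so one diagonal factor suffices.
            return any(is_diagonal(arg) for arg in expr[1:])
        return all(is_diagonal(arg) for arg in expr[1:])
    return False  # unknown op: can't prove diagonality

dense = ("const", np.ones((3, 3)))
diag_from_eye = ("mul", ("eye",), dense)  # eye() * x is diagonal for any x
```

Note that exp is deliberately excluded: exp(0) == 1, so it destroys off-diagonal zeros.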

If the covariance comes from a prior in a model, it's usually an LKJ prior, and we work with that form directly if the user passes it to the chol argument instead of cov. If the user does materialize the covariance with L @ L.T, we can also notice that and undo the process.
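A quick numerical illustration of why the undo is safe (plain NumPy, names illustrative): the Cholesky factor of L @ L.T is unique for lower-triangular L with positive diagonal, so the original factor is exactly recoverable from the materialized covariance.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3
# A lower-triangular factor with positive diagonal, like what an
# LKJ-style Cholesky prior would produce.
L = np.tril(rng.normal(size=(n, n)))
np.fill_diagonal(L, np.abs(np.diag(L)) + 0.1)

cov = L @ L.T                          # user materializes the covariance...
L_recovered = np.linalg.cholesky(cov)  # ...but the factor comes right back
```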

If the covariance comes from linalg.inv(x), we use the logp in terms of the inverse covariance (precision) matrix and avoid the inversion.
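The identity that makes this work, checked with NumPy/scipy rather than the PyMC implementation: with precision P = inv(cov), log|cov| = -log|P| and the quadratic form can use P directly, so the density never needs the inverse.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(2)
n = 3
A = rng.normal(size=(n, n))
P = A @ A.T + n * np.eye(n)  # a precision (inverse covariance) matrix
mu = rng.normal(size=n)
x = rng.normal(size=n)

# Naive path: materialize cov = inv(P), then evaluate the density.
logp_naive = stats.multivariate_normal(mean=mu, cov=np.linalg.inv(P)).logpdf(x)

# Precision form: log|cov| = -log|P|, quadratic form uses P directly.
r = x - mu
sign, logdet_P = np.linalg.slogdet(P)
logp_prec = 0.5 * (logdet_P - n * np.log(2 * np.pi) - r @ P @ r)
```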

A bit of this is still very much WIP. In the end it's always better/safer to start from the optimized form: use univariate normals if there's no meaningful covariance, pass the chol argument if you have a Cholesky factor, etc.

Still, we want to handle the worst-case scenarios, in case the user is still in an ideation phase or lacks the domain knowledge. The most obvious example is probably linalg.inv(x) @ b -> solve(x, b).

A nice thing I don't think other graph-based autodiff libraries do is that we run some simplification rewrites before building the gradient graph. That can yield nicer autodiff graphs than if we built the gradient immediately and only then tried to rewrite it.
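A small NumPy example of the kind of win this can give (illustrative, not PyTensor's actual gradient code): differentiating log|diag(d)| as written uses the generic matrix rule grad = inv(D), while rewriting it to sum(log(d)) first means autodiff only needs the elemwise rule 1/d, with no matrix algebra in the gradient graph.

```python
import numpy as np

rng = np.random.default_rng(4)
d = rng.uniform(0.5, 2.0, size=5)

# Gradient of log|diag(d)| built "as written": generic reverse-mode
# rule for logdet produces a full matrix inverse.
grad_naive = np.diag(np.linalg.inv(np.diag(d)))

# If the expression is first rewritten to sum(log(d)), the gradient is
# just the elemwise derivative of log: 1/d.
grad_rewritten = 1.0 / d
```

Both gradients agree, but the rewritten one never touches an n x n matrix.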

If your gradient is built dynamically by taping the forward pass (like stan-math does), then that doesn't matter, but that approach has other downsides, like not reusing as much shared computation/memory between the forward and backward passes.