Hi,
I am writing a model where I need to sample N different MvNormal variables, which all have the same shape but are each parameterized by a different mean vector and covariance matrix.
My code right now looks something like this:
for n in range(N):
    pm.MvNormal(f"x_{n}", mu=my_means[n], cov=my_covs[n], shape=q)
where my_means has shape (N, q) and my_covs has shape (N, q, q).
However, this is quite slow and also makes the graphviz for the model hard to read. Is there a vectorized way to declare these variables?
Thank you in advance for your help!
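How big are we talking for N and q? A set of independent MvN distributions is equivalent to a single MvN whose mean is the stacked means and whose covariance is the block-diagonal matrix built from the individual covariance matrices. It might get slow if N * q is very large, though, because the dense covariance has (N * q)^2 entries and its Cholesky factorization scales roughly as (N * q)^3.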
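A quick numerical check of that equivalence, as a sketch using SciPy rather than PyMC (the sizes, random means and covariances below are made up for illustration): the sum of N separate MvN log-densities should match the log-density of one MvN with stacked means and a block-diagonal covariance.

```python
import numpy as np
from scipy.linalg import block_diag
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
N, q = 4, 3  # small illustrative sizes, not from the thread

# Random means and positive-definite covariances, one pair per independent MvN
means = rng.normal(size=(N, q))
A = rng.normal(size=(N, q, q))
covs = A @ np.transpose(A, (0, 2, 1)) + q * np.eye(q)

# One evaluation point per component
x = rng.normal(size=(N, q))

# Sum of the N independent log-densities
separate = sum(multivariate_normal(means[n], covs[n]).logpdf(x[n]) for n in range(N))

# One big MvN: stacked (flattened) means and a block-diagonal covariance
joint = multivariate_normal(means.ravel(), block_diag(*covs)).logpdf(x.ravel())
```

The two numbers agree up to floating-point error because a block-diagonal covariance means the blocks are independent, so the joint density factorizes.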
Thanks!
q is relatively small (< 5), but N is fairly big (several thousand observations).
Do you think that is too big? If not, is there an efficient way to convert my (N, q, q) tensor into the block-diagonal format?
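One loop-free way to do the conversion in NumPy, offered as a sketch (the function name `batched_block_diag` is hypothetical): view the output as an (N, q, N, q) array, write each covariance into its (i, :, i, :) slot with a single fancy-indexing assignment, then reshape. This assumes the covariances are fixed data; if they are random variables inside the model, the same idea would have to be translated to PyTensor operations, which is not shown here.

```python
import numpy as np

def batched_block_diag(covs):
    """Turn an (N, q, q) stack of matrices into one (N*q, N*q) block-diagonal matrix."""
    n, q, _ = covs.shape
    out = np.zeros((n, q, n, q), dtype=covs.dtype)
    idx = np.arange(n)
    # Sets out[i, :, i, :] = covs[i] for every i in one vectorized assignment
    out[idx, :, idx, :] = covs
    # Row i*q + a, column j*q + b corresponds to block (i, j), entry (a, b)
    return out.reshape(n * q, n * q)
```

For example, `batched_block_diag(np.ones((2, 2, 2)))` gives a 4 x 4 matrix with two 2 x 2 blocks of ones on the diagonal and zeros elsewhere. Note that the result is still dense, so memory grows as (N * q)^2.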
I wrote this code for doing it a while ago, but maybe Kronecker products are a better choice, like this:
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

# Make a set of (N, N) matrices of all zeros with a 1 in the (i, i)th position (here N = 5)
indicators = [np.zeros((5, 5)) for _ in range(5)]
for i in range(5):
    indicators[i][i, i] = 1

with pm.Model() as mod:
    L = pm.Normal('L', size=(5, 3, 3))
    # This is just to make a block of positive semi-definite matrices, you can ignore it
    # pt.einsum when?
    covs, _ = pytensor.scan(lambda x: x @ x.T, sequences=[L])
    # Here's where the actual block-diagonal matrix gets made
    big_cov = pt.stack([pt.linalg.kron(indicators[i], covs[i]) for i in range(5)]).sum(axis=0)
    obs = pm.MvNormal('obs', mu=0, cov=big_cov)
I have no idea which is faster, but this way is certainly less code and avoids a scan. The kron version samples without error.
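Why the kron trick works, shown as a plain NumPy sketch with made-up sizes: `kron(E_i, C_i)`, where `E_i` is the N x N indicator with a single 1 at (i, i), is an (N*q, N*q) matrix that holds `C_i` in diagonal block (i, i) and zeros everywhere else, so summing over i assembles the block-diagonal covariance.

```python
import numpy as np

rng = np.random.default_rng(2)
N, q = 5, 3  # illustrative sizes matching the snippet above
covs = rng.normal(size=(N, q, q))

# indicators[i] is the N x N matrix E_i with a single 1 at (i, i)
indicators = np.zeros((N, N, N))
indicators[np.arange(N), np.arange(N), np.arange(N)] = 1.0

# Each kron(E_i, C_i) is (N*q, N*q) with C_i in diagonal block (i, i)
big = sum(np.kron(indicators[i], covs[i]) for i in range(N))
```

One design consideration: the stack-and-sum form materializes N full (N*q, N*q) matrices before summing, so with several thousand components its memory use grows much faster than the single dense result, which may matter more than the cost of a scan.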
edit:
I think I misunderstood @ricardoV94’s question; the model with a single MvN and batch dimension doesn’t sample and runs into the linked issue
Thanks so much! I’ll give this a try and see if it works for my case. 
If anyone stumbles upon this thread: MvNormal now accepts arbitrarily batched parameters.