For some time I have been trying to build multivariate models based on WishartBartlett or LKJ priors, first in PyMC3 and more recently in PyMC 5, and it has never really worked: sampling is always so slow that it is completely impractical.
The only decent results I got were from marginalising the covariance matrix, which amounts to writing a StudentT-based custom distribution. Even then it is quite unstable and noticeably more expensive than uncorrelated models, but at least doable.
I have only tried this on a small local machine with 96 cores, but a priori I should not need my Tier-0 supercomputing budget for such a simple problem, so I am wondering whether I am missing something.
Here is a minimalist example, which I would have expected to run fine on any laptop.
First, my synthetic data:
import numpy as np
import pymc as pm

ncfg = 1009
platlength = 3
true_cov = np.array([[.8, .3, .1],
                     [.3, .4, .1],
                     [.1, .1, .9]])
true_mu = np.array([0, 0, 0])
rng = np.random.RandomState(0)
Yplateau = rng.multivariate_normal(mean=true_mu,
                                   cov=true_cov,
                                   size=ncfg)
print(Yplateau.shape)  # (1009, 3)
Now a simple marginalised model, for the case where I aim for a determination of mu and view cov only as a nuisance parameter:
marginalised_model = pm.Model()
with marginalised_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    nu = pm.Exponential("nu", lam=1.0 / np.sqrt(ncfg))
    nup = nu + ncfg
    scalep = ncfg * np.eye(platlength)
    # covariance marginalised out analytically: the data then follow a multivariate StudentT
    Y_obs = pm.MvStudentT("Y_obs", mu=mu, Sigma=scalep / (nup - platlength + 1),
                          nu=nup - platlength + 1,
                          observed=Yplateau)
This samples fine, in less than one minute.
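For reference, a minimal sampling call with the default settings (just a sketch, nothing tuned):

with marginalised_model:
    idata = pm.sample()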
Now, if I try to do basically the same thing but with explicit sampling of the covariance:
wbtrivial_model = pm.Model()
with wbtrivial_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    nu = pm.Exponential("nu", lam=1.0 / np.sqrt(ncfg))
    S = np.eye(platlength)
    # naive sample covariance, used as the starting value (definition added here for completeness)
    plateaunaivecov = np.cov(Yplateau, rowvar=False)
    cov = pm.WishartBartlett('chol_cov', S, nu, is_cholesky=False, return_cholesky=False,
                             initval=plateaunaivecov)
    Y_obs = pm.MvNormal("Y_obs", mu=mu, cov=cov,
                        observed=Yplateau)
then the sampler runs for hours, just to infer a 3x3 covariance matrix!
In this case I can at least get the MAP quickly, but it is essentially just the sample covariance, so not very interesting.
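The MAP call, for completeness (just the standard one):

with wbtrivial_model:
    map_estimate = pm.find_MAP()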
Using LKJ priors is just as slow, and on top of that the MAP seems to make no sense (even with eta=1 the MAP covariance is much more diagonal than the sample covariance):
lkjtrivial_model = pm.Model()
with lkjtrivial_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    # one Exponential(1) prior per standard deviation
    sd_dist = pm.Exponential.dist(1.0, size=platlength)
    chol, corr, sigmas = pm.LKJCholeskyCov(
        'chol_cov', eta=1.0, n=platlength, sd_dist=sd_dist
    )
    Y_obs = pm.MvNormal("Y_obs", mu=mu, chol=chol,
                        observed=Yplateau)
Is there a way to improve either the WishartBartlett or the LKJ model so that it gives me a result within a few minutes on this extremely simple synthetic data?
Or is it just that MvNormal likelihoods are always at least 100-1000x more expensive than a Normal likelihood, and I have to live with that?
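One rough way I could quantify that cost is to compile the log-probability of each model and time a single evaluation; this is only a sketch (it ignores gradients and sampler overhead):

import time

for name, model in [("marginalised", marginalised_model),
                    ("wishart", wbtrivial_model),
                    ("lkj", lkjtrivial_model)]:
    logp_fn = model.compile_logp()   # compiled joint log-probability
    point = model.initial_point()    # dict of starting values
    start = time.perf_counter()
    for _ in range(100):
        logp_fn(point)
    print(name, (time.perf_counter() - start) / 100, "s per logp evaluation")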
In real life my models are of course more complicated: mu is built from a bunch of other random variables.
PS:
There’s some new-ish PyMC feature I hadn’t tried yet which is the possibility of using sample_blackjax_nuts() or sample_numpyro_nuts(). I just tried that on the minimalist example and that’s a game changer. I obtain decent results in a few seconds (despite a lot of divergences). So the conclusion should be never to use pm.sample() any more??
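The call I mean is something like the following (the import path depends on the PyMC version: pymc.sampling.jax in recent releases, pymc.sampling_jax in older ones):

from pymc.sampling.jax import sample_numpyro_nuts

with wbtrivial_model:
    idata = sample_numpyro_nuts()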
On my real-life data, Numpyro is still having a hard time, but I will keep exploring this direction.