Is multivariate inference always so slow?

For some time I have been trying to build multivariate models with WishartBartlett or LKJ priors, first in PyMC3 and more recently in PyMC5, and it has never really worked: sampling is always so slow that it is unusable in practice.
The only decent results I got came from marginalising out the covariance matrix, i.e. writing a custom StudentT-based distribution. That model is quite unstable and noticeably more expensive than an uncorrelated one, but at least it is doable.
I have only tried on a small local machine with 96 cores, but a priori I shouldn’t need my Tier-0 supercomputing budget for such a simple problem. So I am wondering whether I have missed something.

Here is a minimal example, which I would have expected to run fine on any laptop.

Here is my synthetic data:

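```python
import numpy as np
import pymc as pm

ncfg = 1009       # number of observations
platlength = 3    # dimension of each observation
true_cov = np.array([[.8, .3, .1],
                     [.3, .4, .1],
                     [.1, .1, .9]])
true_mu = np.array([0, 0, 0])
rng = np.random.RandomState(0)
Yplateau = rng.multivariate_normal(mean=true_mu, cov=true_cov, size=ncfg)
```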

Now a simple marginalised model, for when I am mainly interested in mu and treat cov only as a nuisance parameter:

```python
marginalised_model = pm.Model()
with marginalised_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    nu = pm.Exponential("nu", lam=1.0 / np.sqrt(ncfg))
    # degrees of freedom and scale of the Student-t used in place of an explicit covariance
    nup = nu + ncfg
    scalep = ncfg * np.eye(platlength)
    Y_obs = pm.MvStudentT("Y_obs", mu=mu, Sigma=scalep / (nup - platlength + 1),
                          nu=nup - platlength + 1, observed=Yplateau)
```

This samples fine, in less than a minute.
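For readers wondering where the Student-t parameters in this model come from, here is the standard conjugacy result it appears to rely on, stated under the assumption (not spelled out in the post) of an inverse-Wishart prior on the covariance. With $p$ the dimension (`platlength`), $\Psi$ the scale matrix and $\nu$ the degrees of freedom, integrating out $\Sigma$ gives:

$$
\Sigma \sim \mathcal{W}^{-1}(\Psi, \nu), \quad y \mid \Sigma \sim \mathcal{N}(\mu, \Sigma)
\quad\Longrightarrow\quad
y \sim t_{\nu - p + 1}\!\left(\mu,\ \frac{\Psi}{\nu - p + 1}\right)
$$

Taking $\Psi$ = `scalep` and $\nu$ = `nup` gives exactly the `Sigma` and `nu` arguments passed to `pm.MvStudentT`. Strictly, the identity integrates out the covariance for a single observation; with `observed=Yplateau` the Student-t is applied to each row independently.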

Now if I do basically the same thing but sample the covariance explicitly:

```python
wbtrivial_model = pm.Model()
with wbtrivial_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    nu = pm.Exponential("nu", lam=1.0 / np.sqrt(ncfg))
    S = np.eye(platlength)
    # covariance with a Wishart prior, parameterised via the Bartlett decomposition
    cov = pm.WishartBartlett('chol_cov', S, nu, is_cholesky=False, return_cholesky=False)
    Y_obs = pm.MvNormal("Y_obs", mu=mu, cov=cov, observed=Yplateau)
```

then the sampler runs for hours. Just to infer a 3x3 covariance matrix!
However in this case I can get the MAP quickly, but this is essentially just the sample covariance, so not very interesting.
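To make the comparison with the MAP concrete, here is a numpy-only sketch that computes the maximum-likelihood covariance (divisor $N$, the "sample covariance" a flat-prior MAP would approach) and the unbiased estimate (divisor $N-1$) from data generated exactly as in the synthetic example. Which of the two a given MAP lands closest to depends on the prior, so treat this as an illustration only.

```python
import numpy as np

# Regenerate the synthetic data exactly as in the example above.
ncfg = 1009
true_cov = np.array([[.8, .3, .1],
                     [.3, .4, .1],
                     [.1, .1, .9]])
true_mu = np.array([0, 0, 0])
rng = np.random.RandomState(0)
Yplateau = rng.multivariate_normal(mean=true_mu, cov=true_cov, size=ncfg)

# Maximum-likelihood estimate (divides by N) and unbiased estimate (divides by N - 1).
centred = Yplateau - Yplateau.mean(axis=0)
cov_ml = centred.T @ centred / ncfg
cov_unbiased = np.cov(Yplateau, rowvar=False)

print(np.round(cov_ml, 2))
print("max |cov_ml - true_cov| =", np.abs(cov_ml - true_cov).max())
```

With about a thousand observations of a 3-dimensional vector, both estimates are already close to `true_cov`, which is why the MAP alone says little here.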

Using LKJ priors is just as slow, and on top of that my MAP seems to make no sense: even with eta=1 the MAP covariance is much more diagonal than the sample covariance.

```python
lkjtrivial_model = pm.Model()
with lkjtrivial_model:
    mu = pm.Flat("mu", initval=Yplateau.mean(axis=0), shape=platlength)
    # prior on the platlength standard deviations (one scalar rate, one draw per dimension)
    sd_dist = pm.Exponential.dist(1.0, size=platlength)
    chol, corr, sigmas = pm.LKJCholeskyCov(
        'chol_cov', eta=1.0, n=platlength, sd_dist=sd_dist
    )
    Y_obs = pm.MvNormal("Y_obs", mu=mu, chol=chol, observed=Yplateau)
```
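As background for comparing the LKJ MAP with the sample covariance: `LKJCholeskyCov` returns the Cholesky factor together with the correlation matrix and the standard deviations, and these are tied by $\Sigma = \operatorname{diag}(\sigma)\, C\, \operatorname{diag}(\sigma)$. Because $\operatorname{diag}(\sigma)$ is diagonal with positive entries, the Cholesky factor of $\Sigma$ is simply $\operatorname{diag}(\sigma)$ times the Cholesky factor of $C$. The numpy sketch below checks this identity on the example's `true_cov`; it only illustrates the algebra and does not reproduce PyMC's internals.

```python
import numpy as np

true_cov = np.array([[.8, .3, .1],
                     [.3, .4, .1],
                     [.1, .1, .9]])

# Split the covariance into standard deviations and a correlation matrix.
sigmas = np.sqrt(np.diag(true_cov))
corr = true_cov / np.outer(sigmas, sigmas)

# Cholesky factor of the covariance = diag(sigmas) @ Cholesky factor of the correlation.
chol_from_parts = np.diag(sigmas) @ np.linalg.cholesky(corr)
chol_direct = np.linalg.cholesky(true_cov)

print(np.allclose(chol_from_parts, chol_direct))
print(np.round(corr, 3))
```

Comparing a MAP on the correlation scale (`corr`) rather than the covariance scale can make it easier to see whether the off-diagonal structure is really being shrunk.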

Is there a way to improve either the WB or the LKJ model so that it gives me a result within a few minutes on this extremely simple synthetic data?
Or are MvNormal likelihoods simply always at least 100-1000x more expensive than Normal likelihoods, and I have to live with that?
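As a rough point of reference for the cost question: for a $p$-dimensional MvNormal, each log-density evaluation needs one Cholesky factorisation of the $p \times p$ covariance (about $p^3/3$ flops, negligible for $p = 3$) plus a triangular solve over the $N \times p$ data, so for $p = 3$ the likelihood arithmetic is within a small factor of $N \cdot p$ independent Normal terms. The numpy sketch below uses hypothetical helpers `mvnormal_logp` and `normal_logp` (not PyMC functions) to write both out and check that they agree when the covariance is diagonal. It says nothing about sampler geometry, gradients or threading, which are the factors the rest of this thread turns to.

```python
import numpy as np

def mvnormal_logp(Y, mu, cov):
    """Sum of log N(y_i | mu, cov) over the rows of Y, using one Cholesky factorisation."""
    n, p = Y.shape
    L = np.linalg.cholesky(cov)             # p x p factorisation, done once per evaluation
    z = np.linalg.solve(L, (Y - mu).T)      # whitened residuals, shape (p, n)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (np.sum(z ** 2) + n * log_det + n * p * np.log(2.0 * np.pi))

def normal_logp(Y, mu, sigma):
    """Sum of independent univariate Normal log-densities, column j having sd sigma[j]."""
    r = (Y - mu) / sigma
    return np.sum(-0.5 * r ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
Y = rng.normal(size=(1009, 3))
mu = np.zeros(3)
sigma = np.array([0.9, 0.6, 0.95])

# With a diagonal covariance the two must agree.
print(np.isclose(mvnormal_logp(Y, mu, np.diag(sigma ** 2)), normal_logp(Y, mu, sigma)))
```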

In real life my models are of course more complicated: mu is built from a bunch of other random variables.

There’s a new-ish PyMC feature I hadn’t tried yet: sampling with sample_blackjax_nuts() or sample_numpyro_nuts(). I just tried it on the minimal example and it’s a game changer: I get decent results in a few seconds (despite a lot of divergences). So should the conclusion be never to use pm.sample() any more??
On my real-life data, Numpyro is still having a hard time, but I will keep exploring this direction.

WB is pretty terrible to sample and should never be used.
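For context on what `WishartBartlett` samples: the Bartlett decomposition writes a Wishart$(S, \nu)$ matrix as $W = L A A^\top L^\top$, where $L$ is the Cholesky factor of $S$ and $A$ is lower triangular with $\sqrt{\chi^2_{\nu - i}}$ on the diagonal ($i = 0, \dots, p-1$) and independent standard normals below it. The sampler therefore explores chi-square and normal latent variables; when $\nu$ is itself a random variable, as in the model above, it also enters the chi-square degrees of freedom. The numpy sketch below (hypothetical helper `wishart_bartlett_draw`) is the textbook construction, not PyMC's source, and checks it against the Wishart mean $\nu S$.

```python
import numpy as np

def wishart_bartlett_draw(S, nu, rng):
    """One draw from Wishart(S, nu) via the Bartlett decomposition (needs nu > p - 1)."""
    p = S.shape[0]
    L = np.linalg.cholesky(S)                    # lower-triangular factor of the scale matrix
    A = np.zeros((p, p))
    # Diagonal: square roots of chi-square variables with nu, nu - 1, ..., nu - p + 1 dof.
    A[np.diag_indices(p)] = np.sqrt(rng.chisquare(nu - np.arange(p)))
    # Strictly lower triangle: independent standard normals.
    rows, cols = np.tril_indices(p, k=-1)
    A[rows, cols] = rng.standard_normal(rows.size)
    LA = L @ A
    return LA @ LA.T

rng = np.random.default_rng(0)
S = np.eye(3)
nu = 10.0
draws = np.array([wishart_bartlett_draw(S, nu, rng) for _ in range(20000)])

# The Wishart mean is nu * S.
print(np.round(draws.mean(axis=0), 2))
```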

LKJ should be fine. There is a known source of slowdown that should be fixed by pymc-devs/pymc Pull Request #6736 on GitHub (“add lower triangular tags to allow chol rewrite”, by dehorsley).

You can try to run fewer chains to see if it’s a CPU saturation issue. Some of the operations are natively multi-threaded and could conflict when sampling multiple chains in parallel.

Not all models are compatible with the JAX samplers, but all of them are compatible with the PyMC sampler. Otherwise, using JAX is getting more and more popular. You can also try nutpie for CPU-based sampling.

Mmmh. I do indeed get a huge speedup in my minimal example when restricting sampling to a single chain.

That’s not something I would have expected, since the number of chains (default 4) was already much smaller than the number of cores (96).

The Wishart-Bartlett model improves as well, and just as with JAX, LKJ and WB once again perform very similarly.

Given that WB is much easier to interpret theoretically, and that its priors are much easier to choose, I still haven’t seen a convincing case for using LKJ.

If I want several chains, then cores=1 works as well:

That’s probably what I am going to do from now on.

Do you have in mind any other way of fixing this that doesn’t involve keeping 95 cores idle?

You might want to play with OMP_NUM_THREADS and the like. This may be helpful:
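A minimal sketch of that suggestion, assuming the slowdown is thread oversubscription: PyMC runs parallel chains in separate processes, and a multi-threaded BLAS/OpenMP library in each process typically defaults to using every core, so 4 chains on 96 cores can end up with 4 × 96 threads competing. Capping the per-process thread count before numpy (and PyMC/PyTensor) is imported avoids that. The variable names below are the standard ones for OpenMP, MKL and OpenBLAS; which ones matter depends on how numpy and PyTensor were built, and `threads_per_chain` is a hypothetical helper, not part of PyMC.

```python
import os

def threads_per_chain(total_cores, chains):
    """Hypothetical helper: split the machine's cores evenly between parallel chains."""
    return max(1, total_cores // chains)

n_threads = threads_per_chain(os.cpu_count() or 1, chains=4)

# Must run before numpy / pymc are imported, otherwise the thread pools are already sized.
for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
    os.environ[var] = str(n_threads)

print(n_threads, os.environ["OMP_NUM_THREADS"])
```

After this, `import pymc as pm` and call `pm.sample(chains=4, cores=4)` as usual. Whether an even split or a single thread per chain is faster is something to measure on the actual machine.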

Link to the new FAQ entry: Frequently Asked Questions, post #19 by ricardoV94


Also, as mentioned before, the fix linked here will also speed up models using LKJ + MvNormal:

This fixed pymc-devs/pymc Issue #6717 on GitHub (“BUG: MvNormal logp recomputes Cholesky factorization”).

It will be included in the next release.
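The rewrite behind that fix rests on a simple identity: a lower-triangular matrix $L$ with positive diagonal is the unique Cholesky factor of $L L^\top$, so factorising `chol @ chol.T` inside the logp just returns `chol` at an extra $O(p^3)$ cost per evaluation. Judging from the PR title, tagging `chol` as lower triangular is what lets PyTensor recognise and skip that step. The numpy check below only demonstrates the identity, not PyTensor's rewrite machinery.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 3
# A lower-triangular factor with positive diagonal, like the one LKJCholeskyCov returns.
L = np.tril(rng.normal(size=(p, p)))
L[np.diag_indices(p)] = np.abs(L[np.diag_indices(p)]) + 0.5

cov = L @ L.T                      # what a covariance-based logp rebuilds
L_again = np.linalg.cholesky(cov)  # the redundant refactorisation

print(np.allclose(L_again, L))
```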

Thanks for all your answers.
I’ll definitely stay tuned and try again as soon as the release reaches conda-forge.

It’s there if you care to check it out @julien :wink:


On the minimal model with a single core there seems to be a marginal improvement, but my workstation is currently full of other jobs, so it’s not the best setting for these tests.
Pre-update (5.5):

Post-update (5.6):

More importantly, the parallelisation issue seems to have disappeared: I can now sample 4 chains (and probably more) for the price of a single one:

I will try to see tomorrow whether this changes something for my real-life models :slight_smile:


What would be considered a large dataset for a MV problem?

I’m using the BVAR example with 5 time series/equations, 20 lags and 10,000 observations.

Using NumPyro on a modest laptop GPU (RTX A2000), it currently takes around 15 hours to run on 5.6.0. Edit: I’m actually on 5.5.0 (will try the new version today).

Is this performance typical? If so any ideas on how to speed it up?
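To put that size in perspective, assuming the BVAR is a standard VAR with 20 lags and an intercept per equation (the post does not give the exact specification): each of the 5 equations regresses on 5 × 20 = 100 lagged values, so the model has 5 × 100 = 500 lag coefficients plus 5 intercepts, and there are 10,000 − 20 = 9,980 usable rows. The numpy sketch below builds that lagged design matrix with a hypothetical helper `lagged_design` and placeholder random data.

```python
import numpy as np

def lagged_design(Y, n_lags):
    """Hypothetical helper: rows [y_{t-1}, ..., y_{t-n_lags}] and targets y_t for t >= n_lags."""
    T, k = Y.shape
    X = np.hstack([Y[n_lags - lag:T - lag] for lag in range(1, n_lags + 1)])
    return X, Y[n_lags:]

rng = np.random.default_rng(0)
Y = rng.normal(size=(10_000, 5))       # placeholder data: 10,000 observations of 5 series
X, targets = lagged_design(Y, n_lags=20)

n_coeffs = X.shape[1] * Y.shape[1]     # lag coefficients across all equations
print(X.shape, targets.shape, n_coeffs)  # (9980, 100) (9980, 5) 500
```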

That sounds quite large. But maybe better to open a new thread and share how you’re implementing it? Priors can matter a lot on BVAR models.
