Multi-Variate Normal in pymc3 vs. R

I have almost finished Richard McElreath’s online course “Statistical Rethinking”, which is pretty good for Bayes beginners like myself. (You can find it by searching YouTube for the playlist code PLDcUM9US4XdNM4Edgs7weiyIguLSToZRI; I cannot post a direct link here because of the forum’s spam filter.)

The latter part of the course uses a lot of multi-variate normal distributions, which are fairly simple and elegant to make in R. See for example the homework examples for that course:

But MV Normal distributions seem to be really complicated in pymc3. See for example the translation of McElreath’s book from R into pymc3:

I have searched this forum and other blogs, and this seems to be the correct way of using MV Normal distributions in pymc3. But it is an awful lot of lines of code that I frankly don’t understand, because I’m not familiar with the underlying math and the tricks used to improve efficiency, numerical stability, and so on.

Since this is fairly simple in R, I wonder why it has to be so complicated in pymc3? MV Normal distributions are quite common, so is this something you might consider simplifying in pymc3 in the future?

Thanks!

PS: Please disable the forum’s spam filter; it is really annoying that it rejects posts with more than 2 links, etc.

Hi,
As a matter of fact, the API for MvNormal is going to be much simpler in the upcoming PyMC 3.9.0 :tada:
I’m in the middle of porting chapter 14 of Rethinking 2 to PyMC3 to show this off; I haven’t had time to finalize the PR yet, but it’s coming :wink:

Basically, before 3.9.0, parametrizing the MvNormal looked like this:

```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt

N_cafes = 20  # hypothetical number of cafes, for illustration only

with pm.Model() as m_13_1:
    sd_dist = pm.HalfCauchy.dist(beta=2)
    packed_chol = pm.LKJCholeskyCov('chol_cov', eta=2, n=2, sd_dist=sd_dist)

    # compute the covariance matrix
    chol = pm.expand_packed_triangular(2, packed_chol, lower=True)
    cov = pm.math.dot(chol, chol.T)

    # Extract the standard deviations and rho
    sigma_ab = pm.Deterministic('sigma_cafe', tt.sqrt(tt.diag(cov)))
    corr = tt.diag(sigma_ab**-1).dot(cov.dot(tt.diag(sigma_ab**-1)))
    r = pm.Deterministic('Rho', corr[np.triu_indices(2, k=1)])

    ab = pm.Normal('ab', mu=0, sd=10, shape=2)  # prior for average intercept and slope
    ab_cafe = pm.MvNormal('ab_cafe', mu=ab, chol=chol, shape=(N_cafes, 2))  # population of varying effects
```
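For readers who, like the original poster, find the `expand_packed_triangular` step opaque: `LKJCholeskyCov` (without `compute_corr=True`) gives back the lower-triangular Cholesky factor as a flat "packed" vector, and the old code unpacks it into a square matrix before computing `cov = chol @ chol.T`. The plain NumPy sketch below mimics that step under the assumption that the packing follows `np.tril_indices` order (row by row); the helper name `expand_packed_lower` and the numbers are hypothetical, not PyMC3 API.

```python
import numpy as np


def expand_packed_lower(n, packed):
    """Rebuild an n x n lower-triangular matrix from its packed entries.

    Assumes the packing order of np.tril_indices(n), i.e. row by row:
    (0,0), (1,0), (1,1), (2,0), ...
    """
    packed = np.asarray(packed, dtype=float)
    if packed.size != n * (n + 1) // 2:
        raise ValueError("packed vector has the wrong length for n")
    chol = np.zeros((n, n))
    chol[np.tril_indices(n)] = packed  # fill the lower triangle, diagonal included
    return chol


# Hypothetical packed Cholesky factor for n=2: [L00, L10, L11]
packed_chol = np.array([2.0, 0.6, 0.8])
chol = expand_packed_lower(2, packed_chol)

# Same role as pm.math.dot(chol, chol.T) in the model above
cov = chol @ chol.T

print(chol)
print(cov)
```

With these numbers the covariance matrix has variances 4 and 1 and a covariance of 1.2, i.e. standard deviations 2 and 1 and a correlation of 0.6.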

And in >= 3.9.0 the same model will be:

```python
import pymc3 as pm
import theano.tensor as tt

N_cafes = 20  # hypothetical number of cafes, for illustration only

with pm.Model() as m14_1:
    # returns the expanded Cholesky factor, the matrix of correlations and the stds automatically
    chol, Rho_, sigma_cafe = pm.LKJCholeskyCov(
        'chol_cov', n=2, eta=2, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
    )

    a = pm.Normal('a', mu=5., sd=2.)  # prior for average intercept
    b = pm.Normal('b', mu=-1., sd=0.5)  # prior for average slope
    ab_cafe = pm.MvNormal('ab_cafe', mu=tt.stack([a, b]), chol=chol, shape=(N_cafes, 2))  # population of varying effects
```
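Both versions put an LKJ prior with `eta=2` on the correlations. As a rough illustration of what `eta` controls, here is a plain NumPy sketch of the LKJ log-density up to its normalizing constant (not PyMC3's implementation): the density is proportional to det(R)^(eta - 1), which for a 2x2 correlation matrix is (1 - rho^2)^(eta - 1). So `eta=1` is flat over correlations, and larger `eta` increasingly favors correlations near zero. The helper names are hypothetical.

```python
import numpy as np


def lkj_logp_unnormalized(corr, eta):
    """Log-density of the LKJ(eta) distribution, up to an additive constant.

    The LKJ density is proportional to det(corr) ** (eta - 1).
    """
    sign, logdet = np.linalg.slogdet(corr)
    if sign <= 0:
        raise ValueError("corr must be positive definite")
    return (eta - 1.0) * logdet


def corr2(rho):
    """2x2 correlation matrix with off-diagonal element rho."""
    return np.array([[1.0, rho], [rho, 1.0]])


# Compare how strongly each eta penalizes larger correlations
for eta in (1.0, 2.0, 4.0):
    values = [round(float(lkj_logp_unnormalized(corr2(r), eta)), 3) for r in (0.0, 0.5, 0.9)]
    print(eta, values)
```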

Thank you! :slight_smile:

Yes, that is much, much better! I’m a beginner in all this and it is confusing enough to try and understand multi-level Bayesian models. And then I also have to worry about low-level details about Cholesky decompositions, matrix unpacking, etc. which makes it really difficult to understand what is going on. So I greatly appreciate that you are simplifying the API!

I don’t mean to rush you, but do you have any idea when this will be released? Are we talking 1 week or 3 months?

Thanks!


Glad you find it much simpler – that was exactly the intent! I actually had the idea of simplifying the API for MvNormal when I myself went through Rethinking 1st edition :laughing:

I wouldn’t say that Cholesky factors and other algebraic tricks are “low-level details”, though: they are actually pretty important for sampling efficiently and, in the end, for getting reliable posterior distributions. I like McElreath’s quote that “the sampler is part of your model”.
That being said, these tricks are definitely overwhelming when starting out, so it’s hard to find the sweet spot between making people aware that they shouldn’t trust the samplers blindly and keeping the API simple enough to use :exploding_head:
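One concrete reason Cholesky factors show up so often: a lower-triangular factor L with L @ L.T == cov lets a log-density be evaluated with cheap triangular solves instead of inverting the full covariance matrix, and the same factor turns independent standard normals z into correlated draws x = mu + L z, which is the idea behind non-centered parametrizations. The NumPy sketch below only illustrates the second point, with hypothetical values for `mu`, the standard deviations and `rho`; it is not how PyMC3 samples.

```python
import numpy as np

rng = np.random.default_rng(42)

mu = np.array([3.5, -1.0])   # hypothetical average intercept and slope
stds = np.array([1.0, 0.5])  # hypothetical standard deviations
rho = -0.7                   # hypothetical correlation
corr = np.array([[1.0, rho], [rho, 1.0]])
cov = np.diag(stds) @ corr @ np.diag(stds)

# Lower-triangular L with L @ L.T == cov
chol = np.linalg.cholesky(cov)

# Draw independent standard normals, then correlate them with the Cholesky factor.
# Row-wise, x = mu + L z becomes draws = mu + z @ L.T.
z = rng.standard_normal(size=(100_000, 2))
draws = mu + z @ chol.T

# The empirical covariance should be close to cov
emp_cov = np.cov(draws, rowvar=False)
print(np.round(emp_cov, 3))
```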

Regarding the 3.9.0 release, the milestone was finished yesterday, so we’re aiming for a release next week :champagne: I’ll post back here once this is done and when chapter 14 of Rethinking 2 is merged in the resources repo :wink:


As promised, I’m back to announce that 3.9.1 is officially released :champagne:


Thanks!

I have finally had time to look at this again. It appears that the tutorial Notebook still uses the old syntax with pm.expand_packed_triangular etc.

Please consider updating the tutorial Notebook to the new and simpler syntax.

I don’t know if you also maintain the Notebooks for ‘Statistical Rethinking’, but it could also do with an update to the improved syntax, if you have the time.

Thanks!

PS: The spam filter on this forum is REALLY annoying! I cannot post links to these Notebooks. Please disable the spam-filter.

The website is indeed not up to date yet (we are in the middle of revamping the tutorial NBs, so we haven’t compiled the docs, but we should do it soon). The tutorial NB itself is already updated, though; you can read it on GitHub :wink:
Regarding the SR2 NBs, they already use the new syntax.

Thanks for the quick reply!

I was viewing the GitHub repo named ‘aloctavodia’ that I linked to in my first post in this thread. I was not aware that the ‘Statistical Rethinking’ Notebooks had moved to another GitHub repo.

Please review your very strict spam-filter settings in this forum, because it makes it difficult to post good and precise questions when the spam-filter blocks internet links so easily.

I don’t know if you are interested in feedback about the LKJ tutorial, but here goes: it requires that the reader already knows a lot about all this. This is really not a tutorial for beginners. I have watched McElreath’s YouTube lectures and I still consider myself a beginner who sort of understands the main ideas, but I found your LKJ tutorial really confusing. There also seem to be remnants of the old syntax, e.g. cells 5-7, which still use packed_L and pm.expand_packed_triangular to calculate something, but apparently their results aren’t used anywhere. I think your LKJ tutorial is really only suitable for people who are already very experienced in all this. For a beginner it is super confusing.

I appreciate your tremendous effort already and that you probably have a million other things to do, but if you consider this topic to be important, then it may be worth your time to write a more beginner-friendly tutorial on how to do Multi-Variate Normals in pymc3.

Thanks again for all your efforts!

PS: I bought a cool hat like you have in your profile picture, expecting that it would make it easier to understand all this Bayes stuff, but pymc3 tells me there’s an 89% probability that the hat didn’t really help me, although I’m not quite sure that my model is correct.


It’s true this tutorial is not for beginners. I didn’t write it initially, but I think its goal is just to focus on the LKJ distribution and the parametrization of the MvNormal, not on the MvNormal per se. So I think its very focused scope is actually a good thing and answers a precise use case: you’re working with MvNormal and wonder how to set up the LKJ prior with PyMC3 --> boom, this is the tutorial.

This is also why I kept a bit of the old syntax: 1) it makes explicit what the new syntax does under the hood, to limit “black-box” usage of pm.LKJCholeskyCov and help people understand the concepts; 2) it helps people who were using the old syntax transition to the new one: as the new syntax was only introduced in 3.9.0, we think it’s important to give people time to adapt.
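To make the “under the hood” part concrete: with `compute_corr=True`, `LKJCholeskyCov` returns the Cholesky factor together with the correlation matrix and the standard deviations, which the old syntax derived by hand from `cov = chol @ chol.T`. The plain NumPy sketch below, using a hypothetical 2x2 Cholesky factor, shows that derivation in both the old matrix-product form and an equivalent broadcast form; it illustrates the math only and does not claim to reproduce PyMC3's internal code.

```python
import numpy as np

# Hypothetical lower-triangular Cholesky factor (n=2)
chol = np.array([[2.0, 0.0],
                 [0.6, 0.8]])
cov = chol @ chol.T

# Standard deviations: square roots of the covariance diagonal
stds = np.sqrt(np.diag(cov))

# Old-style computation: D^-1 @ cov @ D^-1 with D = diag(stds)
inv_d = np.diag(1.0 / stds)
corr_manual = inv_d @ cov @ inv_d

# Equivalent broadcast form: corr[i, j] = cov[i, j] / (stds[i] * stds[j])
inv_stds = 1.0 / stds
corr_broadcast = inv_stds[:, None] * cov * inv_stds[None, :]

# Upper off-diagonal element(s): what the old code stored as 'Rho'
rho = corr_manual[np.triu_indices(2, k=1)]
print(stds, rho)
```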

Regarding tutorials dedicated to MvNormal usage, we already have good resources IMO: the Statistical Rethinking repo as you mentioned above, but also the newly revamped radon NB example. When I was starting out, I found the functions’ docstrings very useful too.
That being said, we’re always happy to add good new tutorials to the catalogue, so feel free to submit one in a PR :tada:
