Automatic Conjugacy Detection

Does PyMC3 have any kind of system to automatically identify when two distributions are conjugate, and then take advantage of the conjugacy for faster sampling? For larger models, using conjugate priors where they are just as suitable as non-conjugate ones can speed up sampling dramatically.
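To make "conjugate" concrete: with a Beta prior on a Binomial success probability, the posterior is again a Beta distribution whose parameters are updated by simple counting, so that variable needs no sampling at all. A minimal plain-Python sketch (the class and function names are illustrative, not PyMC3 API):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BetaPrior:
    alpha: float
    beta: float


def beta_binomial_update(prior: BetaPrior, successes: int, trials: int) -> BetaPrior:
    """Closed-form posterior for p after observing `successes` out of `trials` Binomial trials."""
    if not 0 <= successes <= trials:
        raise ValueError("need 0 <= successes <= trials")
    # Beta(alpha, beta) prior + Binomial likelihood -> Beta(alpha + k, beta + n - k) posterior
    return BetaPrior(prior.alpha + successes, prior.beta + trials - successes)


posterior = beta_binomial_update(BetaPrior(2.0, 2.0), successes=7, trials=10)
print(posterior)  # BetaPrior(alpha=9.0, beta=5.0)
posterior_mean = posterior.alpha / (posterior.alpha + posterior.beta)
print(round(posterior_mean, 4))  # 0.6429
```

Any sampler that recognised this structure could draw `p` exactly from the updated Beta instead of exploring it with generic MCMC moves.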

At the moment, PyMC3 cannot do this automatically; users who want to exploit conjugacy have to choose a compound distribution instead (e.g. the BetaBinomial or Dirichlet-Multinomial distributions).
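A compound distribution such as the BetaBinomial exploits conjugacy by integrating the Beta-distributed probability out of the Binomial analytically, so the sampler never has to visit `p`. In PyMC3 this corresponds to `pm.BetaBinomial(name, alpha=..., beta=..., n=..., observed=...)`. The plain-Python sketch below (helper names are illustrative) writes out the closed-form marginal and checks it against brute-force numerical integration:

```python
import math


def log_beta_fn(a: float, b: float) -> float:
    """log B(a, b) via log-gamma, which avoids overflow for large arguments."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_logpmf(k: int, n: int, alpha: float, beta: float) -> float:
    """log P(k | n, alpha, beta) with the Binomial success probability integrated out."""
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_choose + log_beta_fn(k + alpha, n - k + beta) - log_beta_fn(alpha, beta)


def marginal_by_quadrature(k: int, n: int, alpha: float, beta: float, grid: int = 20_000) -> float:
    """Brute-force check: midpoint-rule integral of Binomial(k | n, p) * Beta(p | alpha, beta) over p."""
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    h = 1.0 / grid
    total = 0.0
    for i in range(grid):
        p = (i + 0.5) * h
        log_lik = log_choose + k * math.log(p) + (n - k) * math.log1p(-p)
        log_prior = (alpha - 1) * math.log(p) + (beta - 1) * math.log1p(-p) - log_beta_fn(alpha, beta)
        total += math.exp(log_lik + log_prior) * h
    return total


exact = math.exp(beta_binomial_logpmf(4, 10, 2.0, 3.0))
approx = marginal_by_quadrature(4, 10, 2.0, 3.0)
print(f"{exact:.6f} {approx:.6f}")  # the two numbers should agree to about six decimal places
```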

Outside of the few compound distributions that are already implemented, users would have to write their own to exploit this.
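As an illustration of what writing your own involves, here is the Dirichlet-Multinomial log-probability worked out by hand in plain Python. In a PyMC3 model it would have to be expressed with Theano tensor operations and attached to the model, for example through `pm.Potential` or `pm.DensityDist`; that wiring is not shown, and the function name is hypothetical:

```python
import math
from typing import Sequence


def dirichlet_multinomial_logpmf(x: Sequence[int], alpha: Sequence[float]) -> float:
    """log P(x | alpha) for counts x with the category probabilities integrated out.

    n is implied by sum(x). With A = sum(alpha):
    log n! + log Gamma(A) - log Gamma(n + A)
      + sum_k [log Gamma(x_k + alpha_k) - log Gamma(alpha_k) - log x_k!]
    """
    if len(x) != len(alpha):
        raise ValueError("x and alpha must have the same length")
    n = sum(x)
    a_sum = math.fsum(alpha)
    out = math.lgamma(n + 1) + math.lgamma(a_sum) - math.lgamma(n + a_sum)
    for xk, ak in zip(x, alpha):
        out += math.lgamma(xk + ak) - math.lgamma(ak) - math.lgamma(xk + 1)
    return out


# Sanity check: probabilities over every way of splitting n = 4 into 3 categories sum to 1.
alpha = (0.5, 1.0, 2.0)
n = 4
total = sum(
    math.exp(dirichlet_multinomial_logpmf((i, j, n - i - j), alpha))
    for i in range(n + 1)
    for j in range(n + 1 - i)
)
print(round(total, 10))  # 1.0
```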

Automatic reparametrization is something we probably want to implement down the road, but for now we still need to implement the compound distributions themselves.

On the other hand, NUTS is often fast enough at sampling many variables jointly that exploiting conjugacy is not critical.


There is also an exotic example of how to build a custom sampler in PyMC3 that takes advantage of conjugacy: "Using a custom step method for sampling from locally conjugate posterior distributions" (PyMC3 3.11.2 documentation).
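The core idea behind such a step method is Gibbs sampling: each locally conjugate block is drawn exactly from its full conditional distribution rather than through a generic proposal. A minimal, framework-free sketch, assuming a Normal likelihood with unknown mean and precision and semi-conjugate Normal/Gamma priors (the function name, priors, and data are illustrative and not taken from the linked example; wiring this into a PyMC3 step-method class is not shown):

```python
import math
import random
from statistics import fmean


def gibbs_normal(y, n_draws=5000, burn=500, m0=0.0, t0=1e-4, a0=1e-3, b0=1e-3, seed=42):
    """Gibbs sampler for y_i ~ Normal(mu, precision tau).

    Priors: mu ~ Normal(m0, precision t0), tau ~ Gamma(shape a0, rate b0).
    Both full conditionals are available in closed form, so no proposals are rejected.
    """
    rng = random.Random(seed)
    n = len(y)
    sum_y = math.fsum(y)
    tau = 1.0  # initial value for the precision
    mus, taus = [], []
    for i in range(n_draws + burn):
        # mu | tau, y is Normal: the Normal prior is conjugate for the mean
        prec = t0 + n * tau
        mean = (t0 * m0 + tau * sum_y) / prec
        mu = rng.gauss(mean, 1.0 / math.sqrt(prec))
        # tau | mu, y is Gamma: the Gamma prior is conjugate for the precision
        shape = a0 + 0.5 * n
        rate = b0 + 0.5 * math.fsum((yi - mu) ** 2 for yi in y)
        tau = rng.gammavariate(shape, 1.0 / rate)  # gammavariate takes a scale, not a rate
        if i >= burn:
            mus.append(mu)
            taus.append(tau)
    return mus, taus


y = [4.8, 5.1, 5.3, 4.9, 5.2, 5.0, 4.7, 5.4]
mus, taus = gibbs_normal(y)
print(round(fmean(mus), 2))  # should be close to the sample mean, 5.05
```

Because every draw comes from an exact conditional, there is no step size or tuning; the trade-off is that this only works for blocks whose conditionals are known in closed form.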