With all the changes, what are the options for ADVI training on GPUs?


I am trying to understand what the current options are for using ADVI on a GPU. These three options come to mind:

  1. aesara compilation to C (using pygpu)
  2. theano-pymc compilation to C (using pygpu)
  3. aesara + JAX, as mentioned here: Pymc3-3.11.0 with GPU support - #9 by twiecki

It seems that approach #1 is not yet recommended (Aesara, theano, theano-pymc - #3 by ricardoV94), and it also does not work in practice (Moving to pymc3 v4 (replaced theano with aesera) by vitkl · Pull Request #59 · BayraktarLab/cell2location · GitHub).
Approach #2 does not work for me; it fails with the same errors as discussed here: https://discourse.pymc.io/t/pymc3-3-11-0-with-gpu-support/.
Approach #3 seems quite experimental. In addition, I found that JAX uses twice as much GPU memory as pymc3+theano and pyro.
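One possible factor behind the higher memory reading, offered as an assumption rather than a diagnosis: by default JAX preallocates a large fraction of GPU memory when it first initializes the GPU backend, so tools such as nvidia-smi can report far more memory than the model actually needs. A minimal sketch of turning that off with JAX's XLA_PYTHON_CLIENT_* environment variables, which have to be set before jax is imported (directly or indirectly through aesara/pymc):

```python
import os

# Read by JAX when it initializes the GPU backend, so set these
# before anything imports jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand

# Alternative: keep preallocation but cap it, e.g. at 50% of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# import jax  # only after the environment is configured
```

With preallocation disabled the reported figure tracks actual allocations, which makes a comparison against pymc3+theano or pyro fairer; the trade-off is that on-demand allocation can be slower and more prone to fragmentation.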

Based on this, I conclude that there is currently no way to use pymc3 ADVI on a GPU. Am I wrong, or is this a good time to start switching to pymc3 4.0 + aesara?

Yes, pygpu support is not working and will be dropped. JAX is the way forward, but we still have to add VI support for PyMC 4.0 (https://github.com/pymc-devs/pymc/pull/4582); once that lands, it would be the route for ADVI on a GPU.

Are you sure that you need it, though? Slow models can usually be sped up a lot by better parameterization.
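As one illustration of what better parameterization can mean (the reply does not name a specific technique), a common fix for slow hierarchical models is the non-centered parameterization: instead of drawing theta ~ Normal(mu, sigma) directly, draw a standard-normal offset z and set theta = mu + sigma * z. The distribution of theta is unchanged, but inference now works with z, which is a priori independent of sigma; when sigma is itself a latent parameter, this removes the funnel-shaped coupling that slows sampling and VI. A small NumPy check of the equivalence, with arbitrary illustrative values for mu and sigma:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n = 1.5, 0.3, 200_000  # arbitrary illustrative values

# Centered: sample theta directly from Normal(mu, sigma).
theta_centered = rng.normal(mu, sigma, size=n)

# Non-centered: sample a standard-normal offset, then shift and scale it.
z = rng.standard_normal(n)
theta_noncentered = mu + sigma * z

# Both parameterizations describe the same distribution of theta.
print(theta_centered.mean(), theta_noncentered.mean())  # both close to 1.5
print(theta_centered.std(), theta_noncentered.std())    # both close to 0.3
```

In a PyMC model the same idea typically becomes z = pm.Normal("z", 0, 1, shape=...) followed by theta = pm.Deterministic("theta", mu + sigma * z).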

Since pymc 4.0 has been released, what’s the update on this? Does ADVI work with aesara + JAX now, and how do I set it up?

In principle it should. You can try:

import aesara
aesara.config["mode"] = "JAX"

Then run ADVI.


Using pymc version ‘4.0.1’ and aesara version ‘2.7.3’, I get:

Input In [21], in <module>
     11 import aesara
---> 12 aesara.config["mode"] = "JAX"

TypeError: 'AesaraConfigParser' object does not support item assignment

It seems that

import aesara
aesara.config.mode = "JAX"

works, but somehow it made the inference even a bit slower?

Yeah, that’s certainly possible. This is still untested with ADVI. Because ADVI is implemented in aesara, it all gets compiled to C by default already. Our samplers, by contrast, are written in Python, so switching to the JAX samplers removes Python overhead; ADVI has much less of that overhead to remove.
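To make the point concrete: each ADVI iteration is a Monte Carlo estimate of the ELBO gradient followed by a parameter update, and that whole computation is what aesara compiles. The following is a minimal NumPy sketch of that step for a hypothetical one-parameter conjugate model, not PyMC's implementation: a hand-derived gradient stands in for automatic differentiation, and plain gradient ascent stands in for the optimizer pm.fit uses.

```python
import numpy as np

# Hypothetical toy model: x_i ~ Normal(theta, 1), prior theta ~ Normal(0, 1).
rng = np.random.default_rng(1)
x = rng.normal(2.0, 1.0, size=100)
n, s_x = x.size, x.sum()

def grad_log_joint(theta):
    """d/dtheta of log p(x, theta) for the toy model (hand-derived)."""
    return s_x - (n + 1) * theta

# Mean-field Gaussian q(theta) = Normal(m, exp(omega)).
m, omega = 0.0, 0.0
lr, n_mc = 0.005, 100  # step size and Monte Carlo samples per step

for _ in range(3000):
    eps = rng.standard_normal(n_mc)
    scale = np.exp(omega)
    theta = m + scale * eps                       # reparameterization trick
    g = grad_log_joint(theta)
    grad_m = g.mean()                             # ELBO gradient w.r.t. m
    grad_omega = (g * eps * scale).mean() + 1.0   # +1 from the entropy term
    m += lr * grad_m                              # gradient ascent on the ELBO
    omega += lr * grad_omega

# Exact posterior for comparison (the model is conjugate).
post_mean, post_sd = s_x / (n + 1), 1 / np.sqrt(n + 1)
print(m, post_mean)
print(np.exp(omega), post_sd)
```

The per-iteration work is a handful of array operations, so a different backend mainly pays off when those arrays are large enough for a GPU to help.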

I would imagine you can still get speed-ups with JAX if you run on the GPU.