I am trying to understand what the current options are for running ADVI on a GPU. These three options come to mind:
1. aesara compilation to C (using pygpu)
2. theano-pymc compilation to C (using pygpu)
3. aesara + JAX, as mentioned in "Pymc3-3.11.0 with GPU support" (#9 by twiecki)
It seems that approach #1 is not yet recommended (see "Aesara, theano, theano-pymc", #3 by ricardoV94), and it also does not work in practice; see "Moving to pymc3 v4 (replaced theano with aesera)" by vitkl, Pull Request #59 in BayraktarLab/cell2location on GitHub.
Approach #2 does not work for me either; I get the same errors as those discussed in this thread: https://discourse.pymc.io/t/pymc3-3-11-0-with-gpu-support/
Approach #3 seems quite experimental. In addition, I found that JAX uses 2x as much GPU memory as pymc3+theano and pyro.
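One factor that may be worth ruling out in that memory comparison: by default, JAX preallocates a large fraction of total GPU memory when its GPU backend first initialises, so tools like nvidia-smi can report far more than the model actually uses. JAX reads the `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables at that point, so they have to be set before `import jax`. The sketch below builds such an environment for a fresh child process. The helper name `jax_memory_env` and the script name in the comment are hypothetical, chosen for illustration only.

```python
import os


def jax_memory_env(preallocate=False, mem_fraction=None):
    """Return a copy of the current environment with JAX GPU allocator settings.

    JAX reads these variables when its GPU backend initialises, so they only
    take effect if set before `import jax`, e.g. in a freshly started process.
    """
    env = dict(os.environ)
    # "false" makes JAX allocate GPU memory on demand instead of up front.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory to preallocate (used when preallocating).
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    return env


# Usage (hypothetical script name; run each backend in its own process):
#   import subprocess, sys
#   subprocess.run([sys.executable, "advi_benchmark.py"],
#                  env=jax_memory_env(), check=True)
```

Running each backend in a separate process keeps one library's allocator from affecting the numbers measured for another.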
Based on this, I conclude that there is currently no way to use pymc3 ADVI on a GPU. Am I wrong, or is this a good time to start switching to pymc3 4.0 + aesara?