Based on this, I conclude that there is currently no way to run PyMC3 ADVI on a GPU. Am I wrong, or is this a good time to start switching to PyMC 4.0 + Aesara?
Yes, support for pygpu is not working and will be dropped. JAX is the way forward, but we still have to add VI support to PyMC 4.0 (https://github.com/pymc-devs/pymc/pull/4582). Once that lands, JAX will be the route to take.
Are you sure you need it, though? Slow models can usually be sped up a lot through better parameterization.
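One common example of "better parameterization" is rewriting a hierarchical normal from the centered form theta ~ Normal(mu, tau) to the non-centered form theta = mu + tau * z with z ~ Normal(0, 1). The sketch below uses only the Python standard library (not PyMC) to check that both forms describe the same distribution; mu, tau and the sample size are arbitrary illustrative values. The benefit for inference is that the sampler (or ADVI) works with z, whose prior scale does not depend on tau, which often avoids the funnel-shaped geometry that makes the centered form slow.

```python
import random
import statistics


def centered_draws(mu, tau, n, rng):
    """Centered form: draw theta directly from Normal(mu, tau)."""
    return [rng.gauss(mu, tau) for _ in range(n)]


def noncentered_draws(mu, tau, n, rng):
    """Non-centered form: draw z ~ Normal(0, 1), then theta = mu + tau * z."""
    return [mu + tau * rng.gauss(0.0, 1.0) for _ in range(n)]


rng = random.Random(0)  # fixed seed so the run is reproducible
mu, tau, n = 1.5, 0.3, 50_000  # illustrative values only

centered = centered_draws(mu, tau, n, rng)
noncentered = noncentered_draws(mu, tau, n, rng)

# Both sets of draws should have a mean close to mu and an sd close to tau.
for name, draws in (("centered", centered), ("non-centered", noncentered)):
    print(f"{name}: mean={statistics.mean(draws):.3f}, sd={statistics.stdev(draws):.3f}")
```

In a PyMC model, the same change amounts to replacing `pm.Normal("theta", mu, tau)` with a standard-normal `z = pm.Normal("z", 0, 1)` plus `theta = pm.Deterministic("theta", mu + tau * z)`.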
Using pymc version 4.0.1 and aesara version 2.7.3:
Input In [21], in <module>
11 import aesara
---> 12 aesara.config["mode"] = "JAX"
TypeError: 'AesaraConfigParser' object does not support item assignment
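The TypeError comes from the bracket syntax: Aesara's config object exposes its options as attributes, not as mapping keys. A minimal sketch of the attribute form is below, along with `change_flags`, the context-manager alternative for switching the mode only within part of your code. This assumes the `jax` package is installed, since functions compiled in JAX mode need it. It only addresses the assignment error; whether ADVI then compiles and runs correctly in JAX mode is a separate question.

```python
import aesara

# Option 1: set the global default compilation mode. Config options are
# attributes on aesara.config, so use "." rather than "[...]".
aesara.config.mode = "JAX"

# Option 2: change the mode only inside this block; aesara.function calls
# made here without an explicit mode= argument pick up mode="JAX".
with aesara.config.change_flags(mode="JAX"):
    pass  # build and compile your graph here
```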
Yeah, that's certainly possible. This is still untested with ADVI. Also, ADVI is implemented in Aesara, so it already gets compiled to C by default, whereas our samplers are written in Python; the speed-up from the JAX samplers comes largely from removing that Python overhead.
I would imagine you can still get speed-ups with JAX if you run on the GPU.