GPU acceleration with Aesara

Hi there, I have created a model in PyMC4 (which uses Aesara) that runs fine on CPU. Now I'm trying to run it on GPU. The PyMC4 docs say that code can easily be run on GPU, so I tried to do this by setting up a .aesararc file in my home folder as follows:

[global]
floatX = float32
device = cuda0

However, this produces the following error:

ValueError                                Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import aesara.tensor as at
      2 import aesara
      3 import arviz as az

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in <module>
     61         continue
     62     raise RuntimeError("You have the aesara directory in your Python path.")
---> 64 from aesara.configdefaults import config
     65 from aesara.utils import deprecated
     68 change_flags = deprecated("Use aesara.config.change_flags instead!")(
     69     config.change_flags
     70 )

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in <module>
   1451 config = aesara.configparser._config
   1453 # The functions below register config variables into the config instance above.
-> 1454 add_basic_configvars()
   1455 add_compile_configvars()
   1456 add_tensor_configvars()

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in add_basic_configvars()
    307 config.add(
    308     "cast_policy",
    309     "Rules for implicit type casting",
    320     ),
    321 )
    323 config.add(
    324     "deterministic",
    325     "If `more`, sometimes we will select some implementation that "
    329     in_c_key=False,
    330 )
--> 332 config.add(
    333     "device",
    334     ("Default device for computations. only cpu is supported for now"),
    335     DeviceParam("cpu", mutable=False),
    336     in_c_key=False,
    337 )
    339 config.add(
    340     "force_device",
    341     "Raise an error if we can't use the specified device",
    342     BoolParam(False, mutable=False),
    343     in_c_key=False,
    344 )
    346 config.add(
    347     "conv__assert_shape",
    348     "If True, AbstractConv* ops will verify that user-provided"
    352     in_c_key=False,
    353 )

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in AesaraConfigParser.add(self, name, doc, configparam, in_c_key)
    173 # Trigger a read of the value from config files and env vars
    174 # This allow to filter wrong value from the user.
    175 if not callable(configparam.default):
--> 176     configparam.__get__(self, type(self), delete_key=True)
    177 else:
    178     # We do not want to evaluate now the default value
    179     # when it is a callable.
    180     try:

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in ConfigParam.__get__(self, cls, type_, delete_key)
    358         else:
    359             val_str = self.default
--> 360     self.__set__(cls, val_str)
    361 return self.val

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in ConfigParam.__set__(self, cls, val)
    364 if not self.mutable and hasattr(self, "val"):
    365     raise Exception(
    366         f"Can't change the value of {} config parameter after initialization!"
    367     )
--> 368 applied = self.apply(val)
    369 self.validate(applied)
    370 self.val = applied

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in ConfigParam.apply(self, value)
    319 """Applies modifications to a parameter value during assignment.
    321 Typical use cases are casting or the substitution of '~' with the user home directory.
    322 """
    323 if callable(self._apply):
--> 324     return self._apply(value)
    325 return value

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/, in DeviceParam._apply(self, val)
    458 def _apply(self, val):
    459     if val.startswith("opencl") or val.startswith("cuda") or val.startswith("gpu"):
--> 460         raise ValueError(
    461             "You are trying to use the old GPU back-end. "
    462             "It was removed from Aesara."
    463         )
    464     elif val == self.default:
    465         return val

ValueError: You are trying to use the old GPU back-end. It was removed from Aesara.

In particular, I noticed this part of the error message:

    332 config.add(
    333     "device",
    334     ("Default device for computations. only cpu is supported for now")
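The last frame of the traceback shows the mechanism: `DeviceParam._apply` rejects any `device` value that starts with `opencl`, `cuda` or `gpu`. The sketch below reproduces that check with the standard library's `configparser`, so it runs without Aesara installed. `check_device_flag` is a hypothetical helper, not Aesara's code. It assumes the file uses a `[global]` section, as `.aesararc` files normally do. Rejecting every other non-`cpu` value is also an assumption, based on the flag's description ("only cpu is supported for now"), because the traceback cuts off the rest of `_apply`.

```python
import configparser

# Prefixes that DeviceParam._apply (last traceback frame) treats as the removed GPU back-end.
REMOVED_GPU_PREFIXES = ("opencl", "cuda", "gpu")


def check_device_flag(rc_text: str) -> str:
    """Hypothetical helper: read `device` from the [global] section of an
    .aesararc-style string and validate it the way the traceback suggests."""
    parser = configparser.ConfigParser()
    parser.read_string(rc_text)
    # Missing flag falls back to Aesara's default device, "cpu".
    device = parser.get("global", "device", fallback="cpu")
    if device.startswith(REMOVED_GPU_PREFIXES):
        # Same message as the ValueError in the traceback.
        raise ValueError(
            "You are trying to use the old GPU back-end. It was removed from Aesara."
        )
    if device != "cpu":
        # Assumption: other values are rejected too, since only cpu is supported.
        raise ValueError(f"Unsupported device {device!r}; only 'cpu' is supported.")
    return device


if __name__ == "__main__":
    print(check_device_flag("[global]\nfloatX = float32\ndevice = cpu\n"))
    try:
        check_device_flag("[global]\nfloatX = float32\ndevice = cuda0\n")
    except ValueError as err:
        print("Rejected:", err)
```

In short, `device = cuda0` can never pass this check, so the flag has to stay at `cpu` (or be left out).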

Does this mean that GPU support is disabled altogether for now, or is there still a way to get the GPU to work?
Thank you in advance for any help!

Aesara does not support GPUs directly. Instead, GPU support comes through the JAX backend.
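Here is a minimal sketch of that route. It assumes JAX is installed with a CUDA-enabled `jaxlib` and, optionally, Aesara. `jax_default_backend` reports which platform JAX will run on. `compile_with_jax_backend` compiles a tiny Aesara graph with `mode="JAX"`, so it executes through JAX rather than Aesara's default backend. Both helper names are hypothetical. Aesara's own `device` flag stays at its default `cpu`; whether computations land on a GPU is decided by the JAX installation. For sampling, PyMC 4.x exposes the same route through `pymc.sampling_jax` (for example `sample_numpyro_nuts`).

```python
import importlib.util


def jax_default_backend():
    """Hypothetical helper: return the platform JAX will use ("cpu", "gpu", ...),
    or None when JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return jax.default_backend()


def compile_with_jax_backend():
    """Hypothetical helper: compile sum(x**2) through Aesara's JAX linker.

    Returns the compiled function, or None when Aesara or JAX is missing.
    The Aesara `device` flag must be left at its default (`cpu`).
    """
    if importlib.util.find_spec("aesara") is None or importlib.util.find_spec("jax") is None:
        return None
    import aesara
    import aesara.tensor as at

    x = at.vector("x")
    # mode="JAX" selects Aesara's JAX linker instead of the default backend.
    return aesara.function([x], (x ** 2).sum(), mode="JAX")


if __name__ == "__main__":
    # Prints "gpu" when a CUDA-enabled jaxlib can see a GPU.
    print("JAX backend:", jax_default_backend())
    f = compile_with_jax_backend()
    if f is not None:
        print(f([1.0, 2.0, 3.0]))
```

If the backend reports `cpu` on a machine with a GPU, the likely culprit is a CPU-only `jaxlib` rather than anything in `.aesararc`.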


Thank you very much! This is great to know!
In fact, I recently managed to get GPU access via JAX, so I should be good to go now. Here's a thread that discusses the GPU setup for JAX sampling, for anybody else who needs help: