PyMC3 3.11.0 with GPU support

Hello pymc3 experts,

I’m trying to run pymc3 with GPU support, using the following setup:

ubuntu-20.04, kernel 5.4.0-72-generic
I created a conda environment and installed:

pip install pymc3==3.11.0
conda install -c conda-forge pygpu

I installed CUDA and cuDNN from the NVIDIA site. I can run nvidia-smi and it detects my NVIDIA GeForce GTX 1080.

I set up a .theanorc file in my home dir:

floatX = float32
device = cuda0
force_device = True
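For reference, Theano reads .theanorc as an INI file, and as far as I know these flags are expected under a [global] section header (it may simply have been lost when pasting above). Here is a minimal sketch for sanity-checking such a file with the standard library's configparser. The helper name read_global_flags and the sample text are only illustrations, not part of Theano:

```python
import configparser


def read_global_flags(text):
    """Parse .theanorc-style INI text and return the [global] section as a dict."""
    parser = configparser.ConfigParser()
    parser.optionxform = str  # keep key case as written, e.g. "floatX"
    parser.read_string(text)
    if not parser.has_section("global"):
        return {}
    return dict(parser["global"])


# Sample contents (assumed layout, with the [global] header included)
sample = """[global]
floatX = float32
device = cuda0
force_device = True
"""

flags = read_global_flags(sample)
print(flags)
```

Note that configparser raises MissingSectionHeaderError if the text has no section header at all, which is one quick way to spot a malformed file.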



But when I import pymc3 I get the following:

/home/hadron/myutils/anaconda3/envs/myenv1/lib/python3.8/site-packages/theano/gpuarray/ UserWarning: Your cuDNN version is more recent than Theano. If you encounter problems, try updating Theano or downgrading cuDNN to a version >= v5 and <= v7.
Using cuDNN version 7401 on context None
Mapped name None to device cuda0: NVIDIA GeForce GTX 1080 (0000:01:00.0)
ERROR (theano.gpuarray): Could not initialize pygpu, support disabled
Traceback (most recent call last):
  File "/home/hadron/myutils/anaconda3/envs/myenv1/lib/python3.8/site-packages/theano/gpuarray/", line 262, in <module>
  File "/home/hadron/myutils/anaconda3/envs/myenv1/lib/python3.8/site-packages/theano/gpuarray/", line 251, in use
    optdb.add_tags("gpuarray_opt", "fast_run", "fast_compile")
AttributeError: module 'theano.gpuarray.optdb' has no attribute 'add_tags'

As you can see, I used CUDA 10.0 and installed cuDNN. I also tried other versions, starting from CUDA 11.3 and working my way down, to address the message “Your cuDNN version is more recent than Theano”, but that still doesn’t work.
Can anyone help me with this? I’m quite stuck.

FWIW I am also having this problem. Having spent the better part of the last two days installing and uninstalling CUDA and friends, I came here, stumped. It makes me feel less stupid to see a fellow traveller on the road…

Anyway, the warning seems to be triggered by this test:

v = version()
if v >= 7200:
    warnings.warn(
        "Your cuDNN version is more recent than "
        "Theano. If you encounter problems, try "
        "updating Theano or downgrading cuDNN to "
        "a version >= v5 and <= v7."
    )

This makes me think you’d need cuDNN 7.1.x or older (a version number below 7200) before that warning disappears.
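To make the numbers concrete, here is a small sketch of how the integer in the log maps to a dotted version and whether it trips the check quoted above. It assumes the cuDNN 7.x encoding of major*1000 + minor*100 + patch, which matches the 7605 / 7.6.5 pair reported in this thread. The function names and the 7103 value are made up for illustration:

```python
def decode_cudnn_version(v):
    """Split a cuDNN 7.x style integer (e.g. 7605) into (major, minor, patch)."""
    major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def triggers_warning(v, threshold=7200):
    """Mirror the `v >= 7200` comparison from the snippet above."""
    return v >= threshold


for v in (7401, 7605, 7103):  # 7103 is a hypothetical older build
    print(v, decode_cudnn_version(v), triggers_warning(v))
```

So both versions reported here (7401 and 7605) would produce the warning, but it is only a warning and not the cause of the AttributeError.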

However, as far as I can tell, the actual error comes from the Python code, not from the cuDNN version. If you look at gpuarray/ you’ll see that, indeed (as the error correctly states), optdb is a module, not a class. The method add_tags is a member of SequenceDB, found in graph/. The module graph appears to have been renamed recently from gof?

You can prove it by putting a print(optdb) right before the error occurs:

$ THEANO_FLAGS="floatX=float32,device=cuda" python -c 'import theano'
Using cuDNN version 7605 on context None
Mapped name None to device cuda: NVIDIA GeForce GTX TITAN X (0000:01:00.0)
<module 'theano.gpuarray.optdb' from '.../site-packages/theano/gpuarray/'>

In fact, there exists a file, theano/gpuarray/ that might be getting imported instead? I thought this might be a result of the order in which things get imported (since this is happening inside a file), so I re-imported optdb:

    if default_to_move_computation_to_gpu:
        from theano.compile import optdb
        optdb.add_tags("gpuarray_opt", "fast_run", "fast_compile")
        optdb.add_tags("gpua_scanOp_make_inplace", "fast_run")

which yields the following error:

Traceback (most recent call last):
  File ".../site-packages/theano/gpuarray/", line 268, in <module>
  File ".../site-packages/theano/gpuarray/", line 255, in use
  File ".../site-packages/theano/graph/", line 427, in __str__
  File ".../site-packages/theano/graph/", line 419, in print_summary

At this point, I’m having trouble understanding how everyone doesn’t have this problem…
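For what it's worth, here is a self-contained sketch of one way this kind of error can arise in plain Python: when a package's __init__.py binds a name and then imports a submodule with the same name, the import system rebinds that name to the submodule. Everything here (demo_pkg, FakeDB, marker) is a made-up stand-in, and I am only assuming this is what happens inside theano.gpuarray, not asserting it:

```python
import os
import sys
import tempfile
import textwrap

# Contents of a throwaway package's __init__.py (all names are hypothetical).
INIT_SRC = textwrap.dedent('''
    class FakeDB:
        def add_tags(self, *tags):
            return tags

    optdb = FakeDB()            # stands in for the optimizer database object
    from .optdb import marker   # importing the submodule rebinds the name "optdb"

    def use():
        return optdb.add_tags("fast_run")
''')


def reproduce():
    """Build the package in a temp dir, call use(), and return the error text."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg_dir = os.path.join(tmp, "demo_pkg")
        os.makedirs(pkg_dir)
        with open(os.path.join(pkg_dir, "__init__.py"), "w") as fh:
            fh.write(INIT_SRC)
        with open(os.path.join(pkg_dir, "optdb.py"), "w") as fh:
            fh.write("marker = 1\n")
        sys.path.insert(0, tmp)
        try:
            import demo_pkg
            try:
                demo_pkg.use()
            except AttributeError as err:
                return str(err)
            return "no error"
        finally:
            # Clean up so the demo leaves no trace in the interpreter state.
            sys.path.remove(tmp)
            sys.modules.pop("demo_pkg", None)
            sys.modules.pop("demo_pkg.optdb", None)


message = reproduce()
print(message)
```

The message has the same shape as the one in the traceback at the top of the thread, with demo_pkg.optdb in place of theano.gpuarray.optdb.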


edit: I’m on CUDA 10.2, cudnn 7.6.5, ubuntu 16.04, pymc 3.11.2, pygpu 0.7.6, theano-pymc 1.1.2.

I would try sampling on the GPU using the new JAX backend. See “Inefficient use of pmap in jax numpyro sampler?” (Issue #4288, pymc-devs/pymc3 on GitHub) for some info.

“try sampling on the GPU using the new JAX backend”

Can you elaborate a bit on how to do that? Are there instructions somewhere?

Following @junpenglao’s gist, I tried:

$ python -c 'import pymc3.sampling_jax'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/jrporter/miniconda2/envs/pymc/lib/python3.8/site-packages/pymc3/", line 18, in <module>
    from import jax_funcify
  File "/home/jrporter/miniconda2/envs/pymc/lib/python3.8/site-packages/theano/link/jax/", line 86, in <module>
  File "/home/jrporter/miniconda2/envs/pymc/lib/python3.8/site-packages/jax/", line 167, in disable_omnistaging
    raise Exception(
Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see


Yes, you need to downgrade JAX to a version below 0.2.12, as the error message indicates.
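A quick way to check whether the installed JAX is on the right side of that cutoff, as a rough sketch. The 0.2.12 boundary is taken from the exception text above; the helper names are made up, and the version parsing is deliberately simple (it only looks at the first three numeric fields):

```python
import re
from importlib import metadata


def version_tuple(text):
    """Turn a version string such as '0.2.10' into a tuple of up to three ints."""
    return tuple(int(n) for n in re.findall(r"\d+", text)[:3])


def omnistaging_can_be_disabled(jax_version):
    """True if this JAX version predates the 0.2.12 cutoff named in the error."""
    return version_tuple(jax_version) < (0, 2, 12)


try:
    installed = metadata.version("jax")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("jax is not installed in this environment")
else:
    print(installed, "ok" if omnistaging_can_be_disabled(installed) else "too new")
```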

Are there any solutions for this for ADVI fitting?

Is that not working with the JAX backend?

Where can I find out how to do this? There is only documentation for sampling with JAX.

Try setting

import theano

# Disable C compilation by default
theano.config.cxx = ""
# This will make the JAX Linker the default
theano.config.mode = "JAX"

And then just run ADVI as usual.