Unable to get pymc3-JAX to run -- ImportError

Hey all,

I’m excited about pymc3 switching to JAX as its backend!

I wanted to give Thomas’s example notebook a go: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434

However, I’ve been unable to get pymc3 running. I’ve run:

pip install git+https://github.com/pymc-devs/Theano-PyMC

to install Theano-PyMC’s main branch, followed by

pip install git+https://github.com/pymc-devs/pymc3.git@pymc3jax

to get pymc3’s JAX branch. However, when I try to import pymc3, I get the following error:

>>> import pymc3
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/__init__.py", line 42, in <module>
    from .distributions import *
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/__init__.py", line 15, in <module>
    from . import timeseries
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/timeseries.py", line 22, in <module>
    from .continuous import get_tau_sigma, Normal, Flat
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/continuous.py", line 27, in <module>
    from .dist_math import (
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/dist_math.py", line 28, in <module>
    from theano.scan_module import until
ModuleNotFoundError: No module named 'theano.scan_module'

I’d love to give it a go – is there something else I need to do to get the example to work?

Thanks for your help and best,


Can you reinstall the pymc3jax branch? Should be fixed now. Let me know if you run into further problems.

Thanks, it works (and it’s really cool)!

Just a quick clarification (you probably meant this anyway): I reinstalled the pymc3jax branch, not master, since only that branch has the JAX-based samplers.

Yep, @twiecki is merging the two branches, so :soon:

