Unable to get pymc3-JAX to run -- ImportError

Hey all,

I’m excited about pymc3 switching to JAX as its backend!

I wanted to give Thomas’s example notebook a go: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434

However, I’ve been unable to get pymc3 running. I’ve run:

pip install git+https://github.com/pymc-devs/Theano-PyMC

to install Theano-PyMC’s main branch, followed by

pip install git+https://github.com/pymc-devs/pymc3.git@pymc3jax

to get pymc3’s JAX branch. However, when I try to import pymc3, I get the following error:

>>> import pymc3
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/__init__.py", line 42, in <module>
    from .distributions import *
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/__init__.py", line 15, in <module>
    from . import timeseries
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/timeseries.py", line 22, in <module>
    from .continuous import get_tau_sigma, Normal, Flat
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/continuous.py", line 27, in <module>
    from .dist_math import (
  File "/home/martin/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/dist_math.py", line 28, in <module>
    from theano.scan_module import until
ModuleNotFoundError: No module named 'theano.scan_module'
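The traceback says the installed Theano no longer provides `theano.scan_module`, which pymc3's `dist_math.py` still imports, so the pymc3 branch and the Theano-PyMC main branch are out of step. Below is a minimal sketch for checking which layout your Theano exposes. Two things are assumptions: that newer Theano-PyMC moved this code to `theano.scan`, and the helper `first_importable`, which is hypothetical and not part of either library.

```python
import importlib
from typing import Iterable, Optional


def first_importable(candidates: Iterable[str]) -> Optional[str]:
    """Return the first dotted module path that imports cleanly, or None."""
    for name in candidates:
        try:
            importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so a missing
            # package or submodule lands here and we try the next candidate.
            continue
        return name
    return None


if __name__ == "__main__":
    # theano.scan_module is what pymc3's dist_math.py imports in the traceback;
    # theano.scan is an ASSUMED new location in the Theano-PyMC main branch.
    found = first_importable(["theano.scan_module", "theano.scan"])
    print(found or "neither module is importable; is Theano-PyMC installed?")
```

If only the second name imports, the installed pymc3 code predates the Theano rename, and reinstalling a pymc3 branch that has been updated for it is the likely fix.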

I’d love to give it a go. Is there something else I need to do to get the example to work?

Thanks for your help and best,
Martin


Can you reinstall the pymc3jax branch? Should be fixed now. Let me know if you run into further problems.
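A plain `pip install` of the same git URL may not pick up the fix if pip considers the package already installed. One way to force a fresh install into the same environment that runs your notebook is sketched below. The helper name `reinstall_command` is hypothetical; `--force-reinstall` and `--no-deps` are standard pip options.

```python
import subprocess
import sys
from typing import List

# Same URL as in the original install command.
PYMC3_JAX_URL = "git+https://github.com/pymc-devs/pymc3.git@pymc3jax"


def reinstall_command(url: str = PYMC3_JAX_URL) -> List[str]:
    """Build a pip command that reinstalls a package from a git URL."""
    return [
        sys.executable, "-m", "pip", "install",  # pip of the current interpreter
        "--force-reinstall",  # reinstall even if the same version is present
        "--no-deps",          # do not touch the git-installed Theano-PyMC
        url,
    ]


if __name__ == "__main__":
    subprocess.run(reinstall_command(), check=True)
```

`--no-deps` matters here: `--force-reinstall` on its own also reinstalls every dependency, which could replace the Theano-PyMC main branch with a release from PyPI. Drop it only if the branch has picked up new dependencies.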


Thanks, it works (and it’s really cool)!

Just as a quick clarification (you probably meant this anyway): I ended up reinstalling the pymc3jax branch, not master, since only that branch has the JAX-based samplers.

Yep, @twiecki is merging the two branches, so master will have the JAX samplers soon.
