JAX Numpyro backend "IndentationError: unexpected indent"

Hello,
I am trying to sample from a hierarchical model using the Numpyro JAX backend, but get a funny error for some models. To begin, I have sampled successfully (and quickly – thank you v4 devs!) from other models, so the JAX backend is working fine. Unfortunately, because of the problem discussed here, I have had to downgrade my version of jax to v0.2.28 (and consequently jaxlib to v0.3.0). I would hope that updating jax would fix this problem, but it’s not an option yet.

Here is the error message I receive when trying to sample using Numpyro:

Traceback (most recent call last):

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3397 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)

  Input In [77] in <cell line: 1>
    m9_trace = model_fitting_pipeline(

  Input In [8] in model_fitting_pipeline
    trace = sample_model(

  File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:104 in sample_model
    trace = _sample_model(m, sample_kwargs, mcmc_backend)

  File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:37 in _sample_model
    trace = sample_numpyro_nuts(**sample_kwargs)

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:483 in sample_numpyro_nuts
    logp_fn = get_jaxified_logp(model, negative_logp=False)

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:106 in get_jaxified_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:99 in get_jaxified_graph
    return jax_funcify(fgraph)

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/functools.py:889 in wrapper
    return dispatch(args[0].__class__)(*args, **kw)

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:669 in jax_funcify_FunctionGraph
    return fgraph_to_python(

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:791 in fgraph_to_python
    fgraph_def = compile_function_src(

  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:609 in compile_function_src
    mod_code = compile(src, filename, mode="exec")

  File /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:14
    9 9]})
    ^
IndentationError: unexpected indent

And here is the head of the file /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:

def jax_funcified_fgraph(mu_mu_a, sigma_mu_a_log_, mu_a, sigma_a_log_, delta_a, sigma_b_log_, b, sigma_c_log_, c, sigma_d_log_, delta_d, sigma_k_log_, k, alpha_log_):
    # AdvancedSubtensor1(k, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
    auto_1195180 = subtensor(k, auto_1144647)
    # Elemwise{exp,no_inplace}(sigma_d_log__)
    auto_1192542 = exp(sigma_d_log_)
    # AdvancedSubtensor1(c, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
    auto_1195178 = subtensor1(c, auto_1144647)
    # AdvancedSubtensor(b, TensorConstant{[0 0 0 0 0..9 9 9 9 9]}, TensorConstant{[0 0 0 0 0..1 1 1 1 1]})
    auto_1192551 = subtensor2(b, auto_1144647, auto_1144655)
    # Elemwise{exp,no_inplace}(sigma_a_log__)
    auto_1192553 = exp(sigma_a_log_)
    # AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9
     9 9]})
    auto_1195177 = subtensor3(mu_a, auto_1144172)
    # Elemwise{exp,no_inplace}(alpha_log__)
    auto_1192534 = exp(alpha_log_)
    # Elemwise{exp,no_inplace}(sigma_k_log__)

Here is my current understanding: there is some magical code generation by PyMC/Aesara/JAX that produces this temporary Python code. In this code, there is an indentation error at line 14, where a comment from the previous line has been split onto a new line without the continuation being commented out, which causes the parse error.
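The failure mode can be reproduced in plain Python, without Aesara. This is a simplified model of the symptom, not Aesara's actual code generator: make_comment and try_compile are hypothetical helpers for illustration. The sketch compiles a small generated function twice. The first version puts a multi-line description after a single '#', mirroring the split mu_a comment in the generated file. The second version prefixes every physical line with '#'. The first call should report an IndentationError like the one in the traceback, and the second should compile.

```python
def make_comment(text: str, indent: str = "    ") -> str:
    """Hypothetical helper: turn arbitrary, possibly multi-line text into
    Python comment lines, putting a '#' at the start of every line."""
    return "\n".join(f"{indent}# {line}" for line in text.splitlines())


def try_compile(src: str) -> str:
    """Compile generated source and report the outcome as a short string."""
    try:
        compile(src, "<generated>", "exec")
    except SyntaxError as err:  # IndentationError subclasses SyntaxError
        return f"{type(err).__name__}: {err.msg}"
    return "ok"


# Node description whose printed constant contains a line break,
# modelled on the comment for mu_a in the generated file.
node_desc = "AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9\n 9 9]})"

header = "def f(mu_a):\n    y = mu_a\n"
footer = "    return y\n"

# Naive generation: the text is indented line by line, but only the first
# line is prefixed with '#', so "     9 9]})" becomes a stray code line.
naive_comment = "    # " + node_desc.replace("\n", "\n    ")
broken_src = header + naive_comment + "\n" + footer

# Safer generation: every physical line of the description is commented out.
fixed_src = header + make_comment(node_desc) + "\n" + footer

print(try_compile(broken_src))
print(try_compile(fixed_src))
```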

For those familiar with the underlying system, is there anything I can do to fix this? For example, is there a setting I could change to stop the addition of comments in the generated code?

This is an Aesara bug; can you open an issue in the repository? GitHub - aesara-devs/aesara

How do you know it is a problem with Aesara? I am using Aesara v2.6.2 and PyMC v4.0.0b6, but my versions of jax and jaxlib are fairly old. If you're sure I should open an issue, what would I say? I'm not certain I can provide a minimal reproducible example.

Also, I’m pretty sure I had this same model sampling fine when I was using the latest versions of jax and jaxlib.

Because it looks like a formatting bug in how comments are added to the generated source file, and that should never happen. Even the error message by itself, without a reproducible example, would be useful.
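One plausible origin of the line break (inferred from the comments in the generated file, not confirmed in this thread): NumPy wraps long arrays when converting them to strings, and the constants in the comments look like the first and last ten characters of such a string joined by "..". The sketch below shows that str() of a 50-element index array contains a newline, while np.array2string with a large max_line_width keeps the same array on one line. The abbreviate function is a hypothetical helper that mimics the label style seen in the comments, not Aesara's code.

```python
import numpy as np

# A 50-element index vector, loosely resembling the constants in the
# generated comments: 0 repeated five times, then 1, ..., then 9.
idx = np.repeat(np.arange(10), 5)

# str() on an array wraps long output at NumPy's print line width
# (75 characters by default), so this text contains a newline.
wrapped = str(idx)

# Forcing a very wide line keeps the same array on a single line.
single_line = np.array2string(idx, max_line_width=10_000)


def abbreviate(text: str) -> str:
    """Hypothetical label in the style of "[0 0 0 0 0..9 9 9 9 9]":
    keep only the first and last ten characters of the printed array."""
    return text if len(text) <= 20 else text[:10] + ".." + text[-10:]


print("newline in str():", "\n" in wrapped)
print("newline in wide array2string():", "\n" in single_line)
print(abbreviate(single_line))
```

If a truncated label keeps such a newline and is written after a single '#', the remainder lands on a new, uncommented line, which matches the broken mu_a comment.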

Got it, thank you for your feedback. I have opened an issue: Bug in formatting comments in generated code files. · Issue #971 · aesara-devs/aesara · GitHub

I just got this as well.

Are you on PyMC 4.0?

Yes, and the latest release of Aesara, though I had to pip install Aesara after installing PyMC to get v2.7.1.

Yes. PyMC 4 / Aesara 2.6.2

This was fixed in Aesara already, but we haven’t updated the dependency due to some incompatibilities with PyMC. They might not affect you, so you can always try to update Aesara manually on your end.
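If you do upgrade Aesara by hand, it helps to confirm which release is actually active in the environment. The stdlib-only sketch below assumes 2.7.1 as the threshold because that is the release reported in this thread as fixing the error (2.6.2 and 2.6.6 were reported as affected). FIXED_IN, release_tuple, and aesara_has_fix are hypothetical names, and the version parsing is deliberately simplified; packaging.version.Version would be more robust.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Aesara release reported in this thread as fixing the broken-comment error.
FIXED_IN = (2, 7, 1)


def release_tuple(v: str) -> Tuple[int, ...]:
    """Leading numeric part of a version string, e.g. '2.7.1' -> (2, 7, 1).
    Simplified parsing: pre-/post-release and local suffixes are ignored."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '0rc1': stop after the number
            break
    return tuple(parts)


def aesara_has_fix() -> Optional[bool]:
    """True/False if Aesara is installed, None if it is not installed."""
    try:
        installed = version("aesara")
    except PackageNotFoundError:
        return None
    return release_tuple(installed) >= FIXED_IN


print(aesara_has_fix())
```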

Sorry, I got mixed up about which problem this was referring to. The latest update to Aesara fixed this for me (see the issue linked above).

Seems like updating Aesara worked. Thank you.

I get this broken-comment error as well. I am using Aesara 2.6.6, jax 0.3.13, and jaxlib 0.3.10. Could somebody please post the output of their "pip list" so that I can compare it with mine, which I include below (I am using Python 3.9.12)? I created a new conda environment and installed PyMC via:

pip install git+https://github.com/pymc-devs/pymc.git
pip install jupyterlab
pip install numpyro

Here is my output from pip list:

absl-py              1.1.0
aeppl                0.0.31
aesara               2.6.6
anyio                3.6.1
argon2-cffi          21.3.0
argon2-cffi-bindings 21.2.0
arviz                0.12.1
asttokens            2.0.5
attrs                21.4.0
Babel                2.10.1
backcall             0.2.0
beautifulsoup4       4.11.1
bleach               5.0.0
cachetools           5.2.0
certifi              2022.5.18.1
cffi                 1.15.0
cftime               1.6.0
charset-normalizer   2.0.12
cloudpickle          2.1.0
cons                 0.4.5
cycler               0.11.0
debugpy              1.6.0
decorator            5.1.1
defusedxml           0.7.1
entrypoints          0.4
etuples              0.3.5
executing            0.8.3
fastjsonschema       2.15.3
fastprogress         1.0.2
filelock             3.7.1
flatbuffers          2.0
fonttools            4.33.3
idna                 3.3
importlib-metadata   4.11.4
ipykernel            6.13.1
ipython              8.4.0
ipython-genutils     0.2.0
jax                  0.3.13
jaxlib               0.3.10
jedi                 0.18.1
Jinja2               3.1.2
json5                0.9.8
jsonschema           4.6.0
jupyter-client       7.3.4
jupyter-core         4.10.0
jupyter-server       1.17.1
jupyterlab           3.4.3
jupyterlab-pygments  0.2.2
jupyterlab-server    2.14.0
kiwisolver           1.4.2
logical-unification  0.4.5
MarkupSafe           2.1.1
matplotlib           3.5.2
matplotlib-inline    0.1.3
miniKanren           1.0.3
mistune              0.8.4
multipledispatch     0.6.0
nbclassic            0.3.7
nbclient             0.6.4
nbconvert            6.5.0
nbformat             5.4.0
nest-asyncio         1.5.5
netCDF4              1.5.8
notebook             6.4.12
notebook-shim        0.1.0
numpy                1.22.4
numpyro              0.9.2
opt-einsum           3.3.0
packaging            21.3
pandas               1.4.2
pandocfilters        1.5.0
parso                0.8.3
pexpect              4.8.0
pickleshare          0.7.5
Pillow               9.1.1
pip                  21.2.4
prometheus-client    0.14.1
prompt-toolkit       3.0.29
psutil               5.9.1
ptyprocess           0.7.0
pure-eval            0.2.2
pycparser            2.21
Pygments             2.12.0
pymc                 4.0.0
pyparsing            3.0.9
pyrsistent           0.18.1
python-dateutil      2.8.2
pytz                 2022.1
pyzmq                23.1.0
requests             2.28.0
scipy                1.8.1
Send2Trash           1.8.0
setuptools           61.2.0
six                  1.16.0
sniffio              1.2.0
soupsieve            2.3.2.post1
stack-data           0.2.0
terminado            0.15.0
tinycss2             1.1.1
toolz                0.11.2
tornado              6.1
tqdm                 4.64.0
traitlets            5.2.2.post1
typing_extensions    4.2.0
urllib3              1.26.9
wcwidth              0.2.5
webencodings         0.5.1
websocket-client     1.3.2
wheel                0.37.1
xarray               2022.3.0
xarray-einstats      0.2.2
zipp                 3.8.0
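For comparing environments, a shorter alternative to the full pip list is to report only the packages that came up in this thread. This is a minimal sketch using the standard library's importlib.metadata; the RELEVANT list and report_versions are illustrative choices, not an official diagnostic.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions mattered in this thread (illustrative selection).
RELEVANT = ["pymc", "aesara", "aeppl", "jax", "jaxlib", "numpyro", "numpy"]


def report_versions(names):
    """Return {package: installed version or 'not installed'}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


for pkg, ver in report_versions(RELEVANT).items():
    print(f"{pkg:<10} {ver}")
```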

I fixed the problem by installing Aesara 2.7.1. I am wondering why Aesara 2.7.1 is not installed by default when installing PyMC 4.0. Thanks.

Yes, I also wonder why Aesara 2.7.1 is not installed along with PyMC. Perhaps the requirements.txt should be updated?

I don’t know exactly why, but @ricardoV94 suggested it’s because there are some incompatibilities with PyMC still.
