Hello,
I am trying to sample from a hierarchical model using the NumPyro JAX backend, but I get a strange error for some models. To be clear, I have sampled successfully (and quickly, thank you v4 devs!) from other models, so the JAX backend itself is working. Unfortunately, because of the problem discussed here, I have had to downgrade jax to v0.2.28 (and consequently jaxlib to v0.3.0). I would hope that updating jax would fix this problem, but that is not an option yet.
Here is the error message I receive when trying to sample using NumPyro:
Traceback (most recent call last):
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3397 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  Input In [77] in <cell line: 1>
    m9_trace = model_fitting_pipeline(
  Input In [8] in model_fitting_pipeline
    trace = sample_model(
  File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:104 in sample_model
    trace = _sample_model(m, sample_kwargs, mcmc_backend)
  File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:37 in _sample_model
    trace = sample_numpyro_nuts(**sample_kwargs)
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:483 in sample_numpyro_nuts
    logp_fn = get_jaxified_logp(model, negative_logp=False)
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:106 in get_jaxified_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:99 in get_jaxified_graph
    return jax_funcify(fgraph)
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/functools.py:889 in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:669 in jax_funcify_FunctionGraph
    return fgraph_to_python(
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:791 in fgraph_to_python
    fgraph_def = compile_function_src(
  File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:609 in compile_function_src
    mod_code = compile(src, filename, mode="exec")
  File /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:14
    9 9]})
    ^
IndentationError: unexpected indent
And here is the head of the file /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:
def jax_funcified_fgraph(mu_mu_a, sigma_mu_a_log_, mu_a, sigma_a_log_, delta_a, sigma_b_log_, b, sigma_c_log_, c, sigma_d_log_, delta_d, sigma_k_log_, k, alpha_log_):
    # AdvancedSubtensor1(k, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
    auto_1195180 = subtensor(k, auto_1144647)
    # Elemwise{exp,no_inplace}(sigma_d_log__)
    auto_1192542 = exp(sigma_d_log_)
    # AdvancedSubtensor1(c, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
    auto_1195178 = subtensor1(c, auto_1144647)
    # AdvancedSubtensor(b, TensorConstant{[0 0 0 0 0..9 9 9 9 9]}, TensorConstant{[0 0 0 0 0..1 1 1 1 1]})
    auto_1192551 = subtensor2(b, auto_1144647, auto_1144655)
    # Elemwise{exp,no_inplace}(sigma_a_log__)
    auto_1192553 = exp(sigma_a_log_)
    # AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9
     9 9]})
    auto_1195177 = subtensor3(mu_a, auto_1144172)
    # Elemwise{exp,no_inplace}(alpha_log__)
    auto_1192534 = exp(alpha_log_)
    # Elemwise{exp,no_inplace}(sigma_k_log__)
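Here is my current understanding: there is some magical code generation step (Aesara's fgraph_to_python, according to the traceback) that turns the model graph into this temporary Python source, which is then compiled. In the generated code, a comment describing an AdvancedSubtensor1 node has been split across two lines, and the continuation at line 14 ("9 9]})") is not commented out. Python therefore parses it as code, and the extra indentation raises the IndentationError.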
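To make the suspected failure mode concrete, here is a minimal, hypothetical reproduction in plain Python. It is not Aesara's code generator: naive_comment, safe_comment and build_source are made-up helpers, and only the node description string is copied from the generated file above. It assumes, as the generated file suggests, that the function body is indented as a whole after the comments are inserted, so the stray continuation line ends up one space deeper than the surrounding statements. Compiling the naive version fails with a SyntaxError (on Python 3.10 this should be the same "IndentationError: unexpected indent"), while prefixing every line of the description with # compiles fine.

```python
import textwrap

# Hypothetical reproduction of the suspected failure mode; this is NOT Aesara's
# actual code generator. The node description below is copied from the
# generated file and contains an embedded newline.
NODE_DESCRIPTION = "AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9\n 9 9]})"


def naive_comment(text: str) -> str:
    """Hypothetical helper: put '#' only in front of the first line of text."""
    return f"# {text}"


def safe_comment(text: str) -> str:
    """Hypothetical helper: put '#' in front of every line of text."""
    return "\n".join(f"# {line}" for line in text.splitlines())


def build_source(comment_fn) -> str:
    """Assemble a tiny function in the style of jax_funcified_fgraph.

    The body is indented as a whole after the comment is inserted (an
    assumption based on the generated file), so a stray continuation line
    " 9 9]})" ends up with five spaces of indentation instead of four.
    """
    body_lines = [
        "auto_1 = mu_a * 2",
        comment_fn(NODE_DESCRIPTION),
        "auto_2 = auto_1 + 1",
        "return auto_2",
    ]
    body = textwrap.indent("\n".join(body_lines), "    ")
    return "def jax_funcified_fgraph(mu_a):\n" + body + "\n"


def try_compile(src: str):
    """Return the SyntaxError raised by compile(), or None if src compiles."""
    try:
        compile(src, "<generated>", "exec")
    except SyntaxError as err:
        return err
    return None


# The naive version reproduces the broken generated file.
broken_src = build_source(naive_comment)
print(broken_src)
err = try_compile(broken_src)
print(type(err).__name__, "at line", err.lineno, ":", err.msg)

# Commenting out every line of the description makes the source valid again.
fixed_src = build_source(safe_comment)
namespace = {}
exec(fixed_src, namespace)
print(namespace["jax_funcified_fgraph"](3))  # → 7
```

In this toy version the whole problem comes down to a newline inside the node description; whether that is also the root cause inside Aesara is an assumption.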
For those familiar with the underlying system, is there anything I can do to fix this? For example, is there a setting I could change to stop the addition of comments in the generated code?