Hello,
I am trying to sample from a hierarchical model using the NumPyro JAX backend, but I get an odd error for some models. To start, I have sampled successfully (and quickly – thank you v4 devs!) from other models, so the JAX backend itself is working fine. Unfortunately, because of the problem discussed here, I have had to downgrade my version of jax to v0.2.28 (and consequently jaxlib to v0.3.0). I would hope that updating jax would fix this problem, but that is not an option yet.
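The call chain in the traceback below bottoms out in pymc.sampling_jax.sample_numpyro_nuts. A stripped-down sketch of the kind of model and call involved looks roughly like this (placeholder names, priors, and data – not my actual model, which has more levels and more indexed group-level parameters):

import numpy as np
import pymc as pm
from pymc.sampling_jax import sample_numpyro_nuts

# Placeholder data: 10 groups with 5 observations each.
group_idx = np.repeat(np.arange(10), 5)
y_obs = np.random.default_rng(1).normal(size=group_idx.size)

with pm.Model() as m:
    # Partially pooled group intercepts, indexed per observation
    # (the indexing is what produces the AdvancedSubtensor ops seen further down).
    mu_a = pm.Normal("mu_a", 0.0, 1.0)
    sigma_a = pm.HalfNormal("sigma_a", 1.0)
    a = pm.Normal("a", mu_a, sigma_a, shape=10)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=a[group_idx], sigma=sigma, observed=y_obs)

    # Effectively what sample_model()/_sample_model() in the traceback end up calling.
    trace = sample_numpyro_nuts(draws=1000, tune=1000, chains=4, target_accept=0.95)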
Here is the error message I receive when trying to sample with NumPyro:
Traceback (most recent call last):
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3397 in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
Input In [77] in <cell line: 1>
m9_trace = model_fitting_pipeline(
Input In [8] in model_fitting_pipeline
trace = sample_model(
File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:104 in sample_model
trace = _sample_model(m, sample_kwargs, mcmc_backend)
File ~/Developer/haigis-lab/bluishred/bluishred/model_fitting.py:37 in _sample_model
trace = sample_numpyro_nuts(**sample_kwargs)
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:483 in sample_numpyro_nuts
logp_fn = get_jaxified_logp(model, negative_logp=False)
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:106 in get_jaxified_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:99 in get_jaxified_graph
return jax_funcify(fgraph)
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/functools.py:889 in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:669 in jax_funcify_FunctionGraph
return fgraph_to_python(
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:791 in fgraph_to_python
fgraph_def = compile_function_src(
File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/aesara/link/utils.py:609 in compile_function_src
mod_code = compile(src, filename, mode="exec")
File /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:14
9 9]})
^
IndentationError: unexpected indent
And here is the head of the generated file /var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/tmpilbd79jk:
def jax_funcified_fgraph(mu_mu_a, sigma_mu_a_log_, mu_a, sigma_a_log_, delta_a, sigma_b_log_, b, sigma_c_log_, c, sigma_d_log_, delta_d, sigma_k_log_, k, alpha_log_):
# AdvancedSubtensor1(k, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
auto_1195180 = subtensor(k, auto_1144647)
# Elemwise{exp,no_inplace}(sigma_d_log__)
auto_1192542 = exp(sigma_d_log_)
# AdvancedSubtensor1(c, TensorConstant{[0 0 0 0 0..9 9 9 9 9]})
auto_1195178 = subtensor1(c, auto_1144647)
# AdvancedSubtensor(b, TensorConstant{[0 0 0 0 0..9 9 9 9 9]}, TensorConstant{[0 0 0 0 0..1 1 1 1 1]})
auto_1192551 = subtensor2(b, auto_1144647, auto_1144655)
# Elemwise{exp,no_inplace}(sigma_a_log__)
auto_1192553 = exp(sigma_a_log_)
# AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9
9 9]})
auto_1195177 = subtensor3(mu_a, auto_1144172)
# Elemwise{exp,no_inplace}(alpha_log__)
auto_1192534 = exp(alpha_log_)
# Elemwise{exp,no_inplace}(sigma_k_log__)
Here is my current understanding: PyMC/Aesara generates this temporary Python module as part of jaxifying the model's log-probability graph, and the module fails to parse. The comment above each generated assignment embeds the string representation of the corresponding node, and for some TensorConstants that representation spans multiple lines; the wrapped tail of the comment then lands on its own line (line 14 here) without a leading #, so Python hits an unexpected indent and raises the parsing error above.
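To illustrate the failure mode in isolation, here is a hand-written toy example (not actual Aesara output, and the exact whitespace in the real temp file may differ): a comment whose tail spills onto its own indented, un-commented line makes compile() fail with exactly this error.

# Toy reproduction of the parse failure, not the real generated module.
src = """def f(x):
    out = x + 1
    # AdvancedSubtensor1(mu_a, TensorConstant{[0 0 0 0 1.. 9 9
        9 9]})
    return out
"""
compile(src, "<generated>", mode="exec")
# -> IndentationError: unexpected indent  (same as in the traceback above)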
For those familiar with the underlying system, is there anything I can do to fix this? For example, is there a setting I could change to stop the addition of comments in the generated code?