NameError: unbound axis name raised during transformation of variables after sample_numpyro_nuts

Hello,

Has anyone come across an error like this before? If so, any advice would be appreciated.

My likelihood function is a JAX function that I wrap in a PyTensor Op and then unwrap again (via jax_funcify) so that it can be sampled with NumPyro, following the method here.

As the traceback below shows, sampling appears to finish without issue. The error occurs afterwards, during the transformation of variables.

I recently updated to PyMC v5.0.1, installed via the method here.

Complete traceback:

(pymc5) nick@u64:/media/sf_thesis/concrete/repo$ python minimal1DBEFV.py
/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py:39: UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Compiling...
Compilation time =  0:00:42.358845
Sampling...
Running chain 0: 100%|█████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [10:59:46<00:00, 15.83s/it]
Running chain 1: 100%|█████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [10:59:46<00:00, 15.83s/it]
Sampling time =  11:00:00.150459
Transforming variables...
Traceback (most recent call last):
  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line -1, in <module>
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line -1, in sample_numpyro_nuts
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line -1, in _postprocess_samples
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/contextlib.py", line -1, in inner
  File "/tmp/tmphh43n0fy", line -1, in jax_funcified_fgraph
  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line -1, in f
  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line 208, in get_SE_Value
    SE = SquaredError(T_trim, D)

  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line -1, in timeStepping
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NameError: unbound axis name: <UniqueResource None 0>. The following axis names (e.g. defined by pmap) are available to collective operations: [<UniqueResource chain 3>, <UniqueResource samples 4>]

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line 392, in <module>
    idata = pm.sampling_jax.sample_numpyro_nuts(2000, tune=500, chains=2, chain_method='parallel', postprocessing_backend='cpu', postprocessing_chunks=2)#, idata_kwargs={'log_likelihood': False})
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line 662, in sample_numpyro_nuts
    result = _postprocess_samples(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line 161, in _postprocess_samples
    return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 259, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
                                                                ^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 245, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 821, in body
    result = f.call_wrapped(
             ^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 197, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 395, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 999, in scan_bind
    return core.AxisPrimitive.bind(scan_p, *args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2445, in bind
    axis_main = max((axis_frame(a).main_trace for a in used_axis_names(self, params)),
                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2356, in used_axis_names
    subst_axis_names(primitive, params, subst)
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2375, in subst_axis_names
    new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2431, in subst_axis_names_jaxpr
    subst.axis_names |= used_axis_names_jaxpr(jaxpr)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2426, in used_axis_names_jaxpr
    do_subst_axis_names_jaxpr(jaxpr, subst)
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2414, in do_subst_axis_names_jaxpr
    invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2414, in <listcomp>
    invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2391, in subst_axis_names_var
    named_shape = {name: axis_frame(name).size for name in names}
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2391, in <dictcomp>
    named_shape = {name: axis_frame(name).size for name in names}
                         ^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2339, in axis_frame
    raise NameError(
jax._src.traceback_util.UnfilteredStackTrace: NameError: unbound axis name: <UniqueResource None 0>. The following axis names (e.g. defined by pmap) are available to collective operations: [<UniqueResource chain 3>, <UniqueResource samples 4>]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/sf_thesis/concrete/repo/minimal1DBEFV.py", line 392, in <module>
    idata = pm.sampling_jax.sample_numpyro_nuts(2000, tune=500, chains=2, chain_method='parallel', postprocessing_backend='cpu', postprocessing_chunks=2)#, idata_kwargs={'log_likelihood': False})
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line 662, in sample_numpyro_nuts
    result = _postprocess_samples(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/jax.py", line 161, in _postprocess_samples
    return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 627, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 852, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 2173, in map_bind
    primitive.process(top_trace, fun, tracers, params))
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 855, in process
    return trace.process_xmap(self, fun, tracers, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/core.py", line 715, in process_call
    return primitive.impl(f, *tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 655, in xmap_impl
    xmap_callable = make_xmap_callable(
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 732, in make_xmap_callable
    return dispatch.lower_xla_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/dispatch.py", line 519, in lower_xla_callable
    lowering_result = mlir.lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 707, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 988, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1122, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 1327, in _xmap_lowering_rule
    return _xmap_lowering_rule_replica(ctx, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 1370, in _xmap_lowering_rule_replica
    vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nick/anaconda3/envs/pymc5/lib/python3.11/site-packages/jax/experimental/maps.py", line 825, in looped_f
    _, stacked_results = lax.scan(body, 0, (), length=loop_length)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NameError: unbound axis name: <UniqueResource None 0>. The following axis names (e.g. defined by pmap) are available to collective operations: [<UniqueResource chain 3>, <UniqueResource samples 4>]

Versions:

# packages in environment at /home/nick/anaconda3/envs/pymc5:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
absl-py                   1.3.0                    pypi_0    pypi
appdirs                   1.4.4              pyhd3eb1b0_0  
arviz                     0.14.0             pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd4edc92_1    conda-forge
binutils                  2.39                 hdd6e379_1    conda-forge
binutils_impl_linux-64    2.39                 he00db2b_1    conda-forge
binutils_linux-64         2.39                h5fc0e48_11    conda-forge
blackjax                  0.9.6                    pypi_0    pypi
blas                      2.116                  openblas    conda-forge
blas-devel                3.9.0           16_linux64_openblas    conda-forge
brotli                    1.0.9                h5eee18b_7  
brotli-bin                1.0.9                h5eee18b_7  
brotlipy                  0.7.0           py311hd4cff14_1005    conda-forge
bzip2                     1.0.8                h7b6447c_0  
c-ares                    1.18.1               h7f8727e_0  
c-compiler                1.5.2                h0b41bf4_0    conda-forge
ca-certificates           2022.12.7            ha878542_0    conda-forge
cachetools                4.2.2              pyhd3eb1b0_0  
cairo                     1.16.0            ha61ee94_1014    conda-forge
certifi                   2022.12.7          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1          py311h409f033_3    conda-forge
cftime                    1.6.2           py311h4c7f6c3_1    conda-forge
charset-normalizer        2.0.4              pyhd3eb1b0_0  
cloudpickle               2.0.0              pyhd3eb1b0_0  
cons                      0.4.5              pyhd8ed1ab_0    conda-forge
contourpy                 1.0.6           py311h4dd048b_0    conda-forge
cryptography              39.0.0          py311h9b4c7bb_0    conda-forge
curl                      7.87.0               hdc1c0ab_0    conda-forge
cxx-compiler              1.5.2                hf52228f_0    conda-forge
cycler                    0.11.0             pyhd3eb1b0_0  
etuples                   0.3.8              pyhd8ed1ab_0    conda-forge
expat                     2.5.0                h27087fc_0    conda-forge
fastprogress              1.0.0              pyhb85f177_0  
filelock                  3.6.0              pyhd3eb1b0_0  
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 hab24e00_0    conda-forge
fontconfig                2.14.1               hc2a2eb6_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.25.0             pyhd3eb1b0_0  
freetype                  2.12.1               h4a9f257_0  
fribidi                   1.0.10               h36c2ea0_0    conda-forge
gcc                       11.3.0              h02d0930_11    conda-forge
gcc_impl_linux-64         11.3.0              hab1b70f_19    conda-forge
gcc_linux-64              11.3.0              he6f903b_11    conda-forge
gdk-pixbuf                2.42.10              h05c8ddd_0    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
giflib                    5.2.1                h36c2ea0_2    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
graphviz                  7.0.6                h2e5815a_0    conda-forge
gtk2                      2.24.33              h90689f9_2    conda-forge
gts                       0.7.6                h64030ff_2    conda-forge
gxx                       11.3.0              h02d0930_11    conda-forge
gxx_impl_linux-64         11.3.0              hab1b70f_19    conda-forge
gxx_linux-64              11.3.0              hc203a17_11    conda-forge
harfbuzz                  6.0.0                h8e241bc_0    conda-forge
hdf4                      4.2.15               h9772cbc_5    conda-forge
hdf5                      1.12.2          nompi_h4df4325_101    conda-forge
icu                       70.1                 h27087fc_0    conda-forge
idna                      3.3                pyhd3eb1b0_0  
intel-openmp              2022.1.0          h9e868ea_3769  
jax                       0.4.1                    pypi_0    pypi
jaxlib                    0.4.1                    pypi_0    pypi
jaxopt                    0.5.5                    pypi_0    pypi
jpeg                      9e                   h7f8727e_0  
kernel-headers_linux-64   2.6.32              he073ed8_15    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.4           py311h4dd048b_1    conda-forge
krb5                      1.20.1               h81ceb04_0    conda-forge
lcms2                     2.14                 hfd0df8a_1    conda-forge
ld_impl_linux-64          2.39                 hcc3a1bd_1    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libaec                    1.0.6                h9c3ff4c_0    conda-forge
libblas                   3.9.0           16_linux64_openblas    conda-forge
libbrotlicommon           1.0.9                h5eee18b_7  
libbrotlidec              1.0.9                h5eee18b_7  
libbrotlienc              1.0.9                h5eee18b_7  
libcblas                  3.9.0           16_linux64_openblas    conda-forge
libcurl                   7.87.0               hdc1c0ab_0    conda-forge
libdeflate                1.14                 h166bdaf_0    conda-forge
libedit                   3.1.20221030         h5eee18b_0  
libev                     4.33                 h7f8727e_1  
libffi                    3.4.2                h6a678d5_6  
libgcc                    7.2.0                h69d50b8_2  
libgcc-devel_linux-64     11.3.0              h210ce93_19    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgd                     2.3.3                h5aea950_4    conda-forge
libgfortran-ng            12.2.0              h69a702a_19    conda-forge
libgfortran5              12.2.0              h337968e_19    conda-forge
libglib                   2.74.1               h606061b_1    conda-forge
libgomp                   12.2.0              h65d4601_19    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             2.1.4                h166bdaf_0    conda-forge
liblapack                 3.9.0           16_linux64_openblas    conda-forge
liblapacke                3.9.0           16_linux64_openblas    conda-forge
libnetcdf                 4.8.1           nompi_h261ec11_106    conda-forge
libnghttp2                1.51.0               hff17c54_0    conda-forge
libnsl                    2.0.0                h5eee18b_0  
libopenblas               0.3.21          pthreads_h78a6416_3    conda-forge
libpng                    1.6.39               h753d276_0    conda-forge
librsvg                   2.54.4               h7abd40a_0    conda-forge
libsanitizer              11.3.0              h239ccf8_19    conda-forge
libsqlite                 3.40.0               h753d276_0    conda-forge
libssh2                   1.10.0               hf14f497_3    conda-forge
libstdcxx-devel_linux-64  11.3.0              h210ce93_19    conda-forge
libstdcxx-ng              12.2.0              h46fd767_19    conda-forge
libtiff                   4.5.0                h82bc61c_0    conda-forge
libtool                   2.4.7                h27087fc_0    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libwebp                   1.2.4                h1daa5a0_1    conda-forge
libwebp-base              1.2.4                h5eee18b_0  
libxcb                    1.13              h7f98852_1004    conda-forge
libxml2                   2.10.3               h7463322_0    conda-forge
libzip                    1.9.2                hc929e4a_1    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
llvm-openmp               14.0.6               h9e868ea_0  
logical-unification       0.4.5              pyhd8ed1ab_0    conda-forge
lz4-c                     1.9.4                h6a678d5_0  
matplotlib-base           3.6.2           py311he728205_0    conda-forge
minikanren                1.0.3              pyhd8ed1ab_0    conda-forge
mkl                       2022.1.0           hc2b9512_224  
mkl-service               2.4.0           py311hb711fc7_0    conda-forge
multipledispatch          0.6.0                      py_0    conda-forge
munkres                   1.1.4                      py_0  
ncurses                   6.3                  h5eee18b_3  
netcdf4                   1.6.2           nompi_py311hc6fcf29_100    conda-forge
numpy                     1.24.1          py311hbde0eaa_0    conda-forge
numpyro                   0.10.1                   pypi_0    pypi
openblas                  0.3.21          pthreads_h320a7e8_3    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.0.7                h0b41bf4_1    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
packaging                 21.3               pyhd3eb1b0_0  
pandas                    1.5.2           py311h8b32b4d_0    conda-forge
pango                     1.50.12              hd33c08f_1    conda-forge
pcre2                     10.40                hc3806b6_0    conda-forge
pillow                    9.4.0           py311h104bd61_0    conda-forge
pip                       22.3.1             pyhd8ed1ab_0    conda-forge
pixman                    0.40.0               h36c2ea0_0    conda-forge
pooch                     1.4.0              pyhd3eb1b0_0  
pthread-stubs             0.3                  h0ce48e5_1  
pycparser                 2.21               pyhd3eb1b0_0  
pymc                      5.0.1                hd8ed1ab_0    conda-forge
pymc-base                 5.0.1              pyhd8ed1ab_0    conda-forge
pyopenssl                 22.0.0             pyhd3eb1b0_0  
pyparsing                 3.0.4              pyhd3eb1b0_0  
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytensor                  2.8.11          py311h02ec4da_1    conda-forge
pytensor-base             2.8.11          py311h38be061_1    conda-forge
python                    3.11.0          ha86cf86_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd3eb1b0_0  
python-graphviz           0.20.1             pyh22cad53_0    conda-forge
python_abi                3.11                    3_cp311    conda-forge
pytz                      2021.3             pyhd3eb1b0_0  
readline                  8.2                  h5eee18b_0  
requests                  2.27.1             pyhd3eb1b0_0  
scipy                     1.10.0          py311h8e6699e_0    conda-forge
setuptools                65.6.3             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyhd3eb1b0_1  
sysroot_linux-64          2.12                he073ed8_15    conda-forge
tk                        8.6.12               h1ccaba5_0  
toolz                     0.11.2             pyhd3eb1b0_0  
tqdm                      4.64.1                   pypi_0    pypi
typing-extensions         4.1.1                hd3eb1b0_0  
typing_extensions         4.1.1              pyh06a4308_0  
tzdata                    2022g                h04d1e81_0  
urllib3                   1.26.8             pyhd3eb1b0_0  
wheel                     0.37.1             pyhd3eb1b0_0  
xarray                    2022.12.0          pyhd8ed1ab_0    conda-forge
xarray-einstats           0.4.0              pyhd8ed1ab_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.8                h5eee18b_0  
zlib                      1.2.13               h166bdaf_4    conda-forge
zstd                      1.5.2                ha4553b6_0  

Thanks!

Here’s a (close to) minimal working example that causes the error.

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify

import arviz as az
import pymc as pm
import pymc.sampling_jax

TAmb = np.array([26.2, 29.9, 32.7, 34.5, 30.7, 28.8, 29.2, 28.8, 28.7, 29.1])
T0 = 30.
dt = 0.25
alpha = 0.05
Td = 31.

def CalTemp(carry, x):
    Tn_, dt, a = carry
    TAmbn = x
    Tn = Tn_ + a*dt*(Tn_ - TAmbn)
    return (Tn, dt, a), Tn

def get_T_End(T0, dt, alpha, TAmb):
    final, T_All = lax.scan(CalTemp, (T0, dt, alpha), TAmb)
    return T_All[-1]

def get_Squared_Error(T0, dt, alpha, TAmb, Td):
    T_End = get_T_End(T0, dt, alpha, TAmb)
    return (T_End - Td)**2

class BEFV1DLogpValueOp(Op):

    def __init__(self, T0, dt, TAmb, Td):
        self.T0 = T0
        self.dt = dt
        self.TAmb = TAmb
        self.Td = Td

    def make_node(self, alpha,):
        inputs = [alpha]
        outputs = [pt.dscalar('SE')]
        return Apply(self, inputs, outputs)

    def perform(self, node, input, outputs):
        alpha, = input
        SE = get_Squared_Error(self.T0, self.dt, alpha, self.TAmb, self.Td)
        outputs[0][0] = np.asarray(SE, dtype=node.outputs[0].dtype)

BEFV1D_SE_value_op = BEFV1DLogpValueOp(T0, dt, TAmb, Td)

@jax_funcify.register(BEFV1DLogpValueOp)
def BEFV1D_SE_value_dispatch(op, **kwargs):
    def f(alpha,):
        return get_Squared_Error(op.T0, op.dt, alpha, op.TAmb, op.Td)
    return f

with pm.Model() as TempModel:
    alpha_ = pm.Normal('alpha', mu=0.04, sigma=0.01, initval=0.05)
    sigma = pm.HalfNormal('sigma', sigma=2)
    SE = pm.Deterministic('SE', BEFV1D_SE_value_op(alpha_))
    pm.Potential('res', -0.5*SE/sigma**2)

with TempModel:
    idata = pm.sampling_jax.sample_numpyro_nuts(200, tune=50, chains=2, chain_method='parallel', postprocessing_backend='cpu', postprocessing_chunks=2)

When I make the following change, the NameError no longer occurs:

def get_T_End(T0, dt, alpha, TAmb):
    #final, T_All = lax.scan(CalTemp, (T0, dt, alpha), TAmb)
    #return T_All[-1]
    return T0*dt*alpha*TAmb[0]

which indicates that jax.lax.scan is the contributing factor.
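One way to check that the scan-based likelihood is fine under ordinary batching is to vmap it directly in plain JAX, outside PyMC. Going by the else branch of pymc.sampling.jax._postprocess_samples (copied in the workaround later in this thread), the non-chunked path is a nested jax.vmap over chains and draws. The sketch below assumes that, reuses the toy model from the MWE with renamed helpers (cal_temp and squared_error are my names), and feeds it made-up alpha draws. It does not reproduce the bug. It only shows that scan plus nested vmap works without xmap.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# Same toy model as the MWE above: an explicit Euler update inside lax.scan.
TAmb = jnp.array([26.2, 29.9, 32.7, 34.5, 30.7, 28.8, 29.2, 28.8, 28.7, 29.1])
T0, dt, Td = 30.0, 0.25, 31.0


def cal_temp(carry, TAmbn):
    Tn_, dt_, a = carry
    Tn = Tn_ + a * dt_ * (Tn_ - TAmbn)
    return (Tn, dt_, a), Tn


def squared_error(alpha):
    _, T_all = lax.scan(cal_temp, (T0, dt, alpha), TAmb)
    return (T_all[-1] - Td) ** 2


# Made-up stand-in "posterior" draws of alpha, shaped (chains, draws).
alpha_draws = jnp.asarray(
    np.random.default_rng(0).normal(0.04, 0.01, size=(2, 20)), dtype=jnp.float32
)

# Nested vmap: outer over chains, inner over draws (no xmap, no chunking).
se_batched = jax.jit(jax.vmap(jax.vmap(squared_error)))(alpha_draws)
print(se_batched.shape)  # (2, 20)

# Reference: evaluate each draw on its own, one scan at a time.
se_loop = np.array([[squared_error(a) for a in chain] for chain in alpha_draws])
print(np.allclose(np.asarray(se_batched), se_loop, rtol=1e-5, atol=1e-6))
```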

Does anyone know what’s going on here, or whether a workaround exists?

Cheers.

Update number two…

The error only occurs when jax.lax.scan() is used within the likelihood function AND the new postprocessing_chunks argument is passed to sample_numpyro_nuts(). The error does not occur in either of the following situations:

  • jax.lax.scan() used but postprocessing_chunks not specified
  • jax.lax.scan() not used but postprocessing_chunks specified

Unfortunately, both of the above are important for my model so I’m hoping there is a workaround/fix available.

Cheers.

I encountered the exact same error when using cholesky + postprocessing_chunks. Here’s another MWE (without any custom JAX functions):

import pymc as pm
import pymc.sampling_jax
import numpy as np
import pytensor.tensor as pt

N = 10
X = np.arange(N)[:,None]
    
with pm.Model() as model: 
    ls = pm.Exponential('ls', lam=1)
    cov = pm.gp.cov.ExpQuad(1, ls)(X)
    f = pm.Deterministic('f', pt.slinalg.cholesky(cov))

    idata = pm.sampling_jax.sample_numpyro_nuts(50, tune=50, chains=2, chain_method='parallel',
                                                postprocessing_backend='cpu', postprocessing_chunks=2)

I have no idea why this happens, but it renders postprocessing_chunks useless for GP applications (which is ironic, given that postprocessing_chunks is meant to deal with memory issues). I suppose loops are what cholesky and jax.lax.scan have in common. The code behind postprocessing_chunks uses xmap, which is currently experimental, so perhaps it does not play nicely with JAX-ified graphs that involve loops(?).

Notably, if the covariance matrix is constant, then there is no issue.
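When the lengthscale is sampled, the Cholesky factor has to be recomputed for every draw, so the postprocessing function genuinely has to be batched over chains and draws. The sketch below does that batching with plain nested jax.vmap and no xmap, in pure JAX outside PyMC. It writes the squared-exponential kernel by hand in the form pm.gp.cov.ExpQuad uses. The diagonal jitter, the lengthscale range, and the helper name exp_quad_cholesky are my own assumptions and are not part of the MWE.

```python
import numpy as np
import jax
import jax.numpy as jnp

N = 10
X = jnp.arange(N, dtype=jnp.float32)[:, None]
JITTER = 1e-6  # assumption: small diagonal jitter for numerical safety (not in the MWE)


def exp_quad_cholesky(ls):
    # Squared-exponential kernel exp(-0.5 * (x - x')**2 / ls**2), as in pm.gp.cov.ExpQuad.
    K = jnp.exp(-0.5 * (X - X.T) ** 2 / ls**2) + JITTER * jnp.eye(N)
    return jnp.linalg.cholesky(K)


# Made-up stand-in draws of the lengthscale, shaped (chains, draws).
ls_draws = jnp.asarray(
    np.random.default_rng(1).uniform(0.3, 1.0, size=(2, 25)), dtype=jnp.float32
)

# Nested vmap over chains and draws: one Cholesky factor per draw.
chol = jax.jit(jax.vmap(jax.vmap(exp_quad_cholesky)))(ls_draws)
print(chol.shape)  # (2, 25, 10, 10)

# Sanity check on the first draw: L @ L.T should reproduce the kernel matrix.
L0 = chol[0, 0]
K0 = L0 @ L0.T
```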

Here’s a naive workaround: the idea is to rewrite pymc.sampling.jax._postprocess_samples so that it doesn’t use xmap.

from typing import Any, Callable, List, Optional

import jax
import jax.numpy as jnp


# Assumed usage (untested): replace PyMC's version before sampling, e.g.
#   import pymc.sampling.jax
#   pymc.sampling.jax._postprocess_samples = _postprocess_samples
def _postprocess_samples(
    jax_fn: Callable,
    raw_mcmc_samples: List[Any],
    postprocessing_backend: str,
    num_chunks: Optional[int] = None,
) -> List[Any]:
    # Move the raw samples to the postprocessing device (e.g. 'cpu').
    raw_mcmc_samples = jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    # Outer vmap over chains, inner vmap over draws.
    f = jax.vmap(jax.vmap(jax_fn))
    if num_chunks is None:
        return f(*raw_mcmc_samples)
    # dims of each var_samples are chains, draws, ...
    draws = raw_mcmc_samples[0].shape[1]
    step = max(1, draws // num_chunks)  # avoid a zero step when num_chunks > draws
    segs = list(range(0, draws, step)) + [draws]
    # dims are chunks, vars, chains, draws_in_chunk, ...
    outputs = [f(*[var_samples[:, i:j] for var_samples in raw_mcmc_samples])
               for i, j in zip(segs[:-1], segs[1:])]
    # var_chunks holds one variable's chunks; concatenate them along the draws axis
    return [jnp.concatenate(var_chunks, axis=1) for var_chunks in zip(*outputs)]

Essentially, I’m looping through the chunks manually with basic loops. I’ll leave the proper solution to someone who is more well-versed with JAX.
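Below is a self-contained, pure-JAX version of the same manual-chunking idea that can be run without PyMC. It is a sketch and not PyMC's implementation. postprocess_in_chunks and toy_fn are hypothetical names. The helper uses np.array_split so every chunk is non-empty and chunk sizes differ by at most one. It also jit-compiles the nested vmap, which means one compilation per distinct chunk length. toy_fn contains a lax.scan, like the likelihood earlier in this thread, and the chunked result is checked against a single unchunked nested vmap.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


def postprocess_in_chunks(fn, samples, num_chunks):
    """Hypothetical helper: apply `fn` to every (chain, draw) of `samples`,
    one contiguous slice of draws at a time.

    `samples` is a list of arrays shaped (chains, draws, ...); `fn` takes one
    draw of each variable and returns a tuple of outputs.
    """
    # Outer vmap over chains, inner over draws; jit recompiles per distinct chunk length.
    batched = jax.jit(jax.vmap(jax.vmap(fn)))
    draws = samples[0].shape[1]
    # Near-equal contiguous slices of draw indices; empty slices are skipped.
    slices = [idx for idx in np.array_split(np.arange(draws), num_chunks) if idx.size]
    outputs = [batched(*[s[:, idx[0]:idx[-1] + 1] for s in samples]) for idx in slices]
    # zip(*outputs) groups the chunks of each output; stitch them back along the draws axis.
    return [jnp.concatenate(chunks, axis=1) for chunks in zip(*outputs)]


def toy_fn(alpha, sigma):
    # Toy transform containing a lax.scan; ys[-1] works out to 6 * alpha.
    _, ys = lax.scan(lambda c, x: (c + alpha * x, c), 0.0, jnp.arange(5.0))
    return ys[-1], sigma**2


rng = np.random.default_rng(2)
samples = [
    jnp.asarray(rng.normal(size=(2, 7)), dtype=jnp.float32),         # alpha: (chains, draws)
    jnp.asarray(rng.uniform(1, 2, size=(2, 7)), dtype=jnp.float32),  # sigma: (chains, draws)
]
chunked = postprocess_in_chunks(toy_fn, samples, num_chunks=3)
full = jax.vmap(jax.vmap(toy_fn))(*samples)
print([c.shape for c in chunked])  # [(2, 7), (2, 7)]
```

Note that the chunked outputs are concatenated at the end, so this only bounds the memory used by intermediate computations inside fn, not the size of the final outputs.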

It looks like this error only occurs with recent versions of JAX. I resolved the error by downgrading to 0.4.1 (from the current 0.4.8). For my use case (a GP), the downgrade didn’t appear to break anything else.

I get this issue a lot, and without any JAX scan. It does seem to be related to postprocessing_chunks, though.