I am experimenting with pymc4 and JAX, and noticed that after sampling with sample_numpyro_nuts finishes, it takes a very long time before the trace is returned. While I wait, I get these warnings:
2022-05-17 12:16:34.492894: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:
%reduce = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.13, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-05-17 12:18:00.271621: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 2s:
%reduce.2 = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.36, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
This is a relatively simple model. Also, constant folding is a compiler optimization, yet these warnings appear after sampling has finished, which I would have expected to be well after compilation.
Any ideas? I am using pymc 4.0.0.b5, installed through conda-forge.
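In case it helps with diagnosis, this is roughly how I would capture the dump that the warning asks for. It is only a sketch: the helper name add_xla_dump_flag is my own (hypothetical), /tmp/foo is the path suggested in the warning, and I am assuming the variable has to be set before JAX is first imported so that XLA picks it up.

```python
import os


def add_xla_dump_flag(existing_flags: str, dump_dir: str) -> str:
    """Append --xla_dump_to=<dump_dir> to whatever XLA_FLAGS already holds.

    Hypothetical helper, not part of pymc or jax.
    """
    flag = f"--xla_dump_to={dump_dir}"
    # Keep any flags that are already set instead of overwriting them.
    return f"{existing_flags} {flag}".strip()


# Assumption: this runs at the very top of the script or notebook,
# before `import jax` or any pymc sampling import that pulls in jax.
os.environ["XLA_FLAGS"] = add_xla_dump_flag(
    os.environ.get("XLA_FLAGS", ""), "/tmp/foo"
)

# ... then import pymc / jax and call sample_numpyro_nuts as before.
```

Setting the variable in the shell before starting Python should have the same effect; doing it in code just keeps it next to the model when working in a notebook.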