Constant folding taking a long time after sampling in pymc4

I am experimenting with pymc4 and jax, and noticed that after sampling with sample_numpyro_nuts, it takes a very long time until I get the trace back, and I get these warnings:

2022-05-17 12:16:34.492894: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:

  %reduce = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.13, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-05-17 12:18:00.271621: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 2s:

  %reduce.2 = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.36, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
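If you do want to produce the dump that the warning asks for, the environment variable has to be in place before JAX initializes its XLA backend. The snippet below is a minimal sketch that sets `XLA_FLAGS` from inside a Python script, assuming the flag is read when JAX first creates its backend (so it must run before `import jax` or before PyMC's JAX sampler is imported). It appends to any flags already set instead of overwriting them. `/tmp/foo` is simply the example path from the log message.

```python
import os

# Flag suggested by the XLA warning; /tmp/foo is the example path from the log.
DUMP_FLAG = "--xla_dump_to=/tmp/foo"

# Keep any XLA flags that are already set and append the dump flag once.
existing = os.environ.get("XLA_FLAGS", "")
if DUMP_FLAG not in existing.split():
    os.environ["XLA_FLAGS"] = (existing + " " + DUMP_FLAG).strip()

# Import JAX (or pymc.sampling_jax) only after this point, e.g.:
# import jax
```

Setting the variable in the shell before launching Python is equivalent. The in-script form just keeps the flag next to the model code.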

This is a relatively simple model. Also, constant folding is a compiler optimization, but this happens after sampling (so after compilation, obviously)…
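For reference, the op being folded in both warnings is `reduce_prod` over axis 3 of an `f64[4,1000,9900,1]` array. Because that axis has length 1, the product is numerically just a reshape to `f64[4,1000,9900]`. However, folding it at compile time still means evaluating every element, which plausibly explains why it takes seconds. The NumPy sketch below checks that equivalence on small stand-in shapes and works out the size of the array from the log. Reading the leading dimensions as 4 chains by 1000 draws is an assumption based on common sampler defaults; the log itself does not say so.

```python
import numpy as np

# Small stand-in for the f64[4,1000,9900,1] operand in the log.
x = np.random.default_rng(0).normal(size=(2, 3, 5, 1))

# reduce_prod over axis 3 (length 1) is the same as dropping that axis.
folded = np.prod(x, axis=3)
assert folded.shape == (2, 3, 5)
assert np.array_equal(folded, x[..., 0])

# Size of the real result from the log: f64[4,1000,9900].
shape = (4, 1000, 9900)
n_elements = int(np.prod(shape))                        # 39,600,000 values
n_bytes = n_elements * np.dtype(np.float64).itemsize    # 8 bytes per value
print(f"{n_elements:,} elements, {n_bytes / 1e6:.1f} MB")
# prints: 39,600,000 elements, 316.8 MB
```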

Any ideas? I am using pymc 4.0.0.b5, installed through conda-forge.

First, I would recommend upgrading to b6 if you can. Second, maybe @ricardoV94 has some idea about what is happening post-sampling?

This seems to be a JAX compilation issue; there's nothing we can do on our end.

To explain, this is happening after sampling, where we recompute any deterministics or transformed variables from the trace results.
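To make that post-sampling step concrete, here is a minimal JAX sketch. It does not reproduce PyMC's actual code. A hypothetical deterministic (`exp` of a log-scale parameter) is recomputed over a trace of assumed shape (chains, draws) with two nested `jax.vmap` calls. The sketch also contrasts passing the trace into `jax.jit` as an argument with closing over it. In JAX versions of that period, a closed-over array was typically embedded in the compiled program as a constant, and that is the kind of value XLA's constant-folding pass may try to evaluate at compile time. Whether PyMC's post-processing hit exactly this pattern is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # the log shows f64 arrays

# Hypothetical deterministic: map a log-scale draw back to the positive scale.
def deterministic(log_sigma):
    return jnp.exp(log_sigma)

n_chains, n_draws = 4, 1000  # assumed meaning of the leading dims in the log
rng = np.random.default_rng(0)
trace_log_sigma = rng.normal(size=(n_chains, n_draws))

# Outer vmap over chains, inner vmap over draws: one call per posterior draw.
recompute = jax.jit(jax.vmap(jax.vmap(deterministic)))

# Pattern 1: trace passed as a jit argument, so XLA sees a runtime parameter.
sigma = recompute(trace_log_sigma)

# Pattern 2: trace closed over, so it may be embedded as a compile-time
# constant that XLA can attempt to constant-fold (cheap here, slow when huge).
closed_trace = jnp.asarray(trace_log_sigma)
recompute_closed = jax.jit(lambda: jax.vmap(jax.vmap(deterministic))(closed_trace))
sigma_closed = recompute_closed()
```

Both patterns give the same numbers. The difference is only in what the compiler sees as data versus constants. In one's own JAX code, passing large arrays as arguments rather than capturing them is a common way to keep them out of constant folding.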