Constant folding taking a long time after sampling in pymc4

I am experimenting with pymc4 and JAX, and noticed that after sampling with sample_numpyro_nuts it takes a very long time before I get the trace back, and I get these warnings:

2022-05-17 12:16:34.492894: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:

  %reduce = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.13, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-05-17 12:18:00.271621: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 2s:

  %reduce.2 = f64[4,1000,9900]{2,1,0} reduce(f64[4,1000,9900,1]{3,2,1,0} %broadcast.36, f64[] %constant.11), dimensions={3}, to_apply=%region_0.38, metadata={op_name="jit(jax_funcified_fgraph)/jit(main)/reduce_prod[axes=(3,)]" source_file="/opt/conda/envs/pymc4_jax/lib/python3.10/site-packages/aesara/link/jax/dispatch.py" source_line=174}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.

This is a relatively simple model. Also, constant folding is a compiler optimization, yet this happens after sampling has finished (i.e., seemingly after compilation has already taken place)…
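For reference, the call pattern looks roughly like this (a minimal stand-in rather than my actual model; the shapes just mirror the 4 chains × 1000 draws × 9900 observations in the log above):

```python
import numpy as np
import pymc as pm
from pymc.sampling_jax import sample_numpyro_nuts  # pymc 4.x location of the JAX sampler

# Stand-in data with the same 9900 observations as in the log above
y = np.random.normal(size=9900)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)

    # Sampling itself finishes quickly; the long pause (and the XLA warnings)
    # comes between the end of sampling and the trace being returned
    idata = sample_numpyro_nuts(draws=1000, chains=4)
```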

Any ideas? I am using pymc 4.0.0b5 installed through conda-forge.

First, I would recommend upgrading to b6 if you can. Second, maybe @ricardoV94 has some idea about what is happening post-sampling?


This seems to be a JAX compilation issue; there is nothing we can do on our end.


To explain: this happens after sampling, when we recompute any deterministics or transformed variables from the trace results.
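For example, in a model sketched like the one below (hypothetical, just to illustrate), the pm.Deterministic node and the untransformed sigma are what get re-evaluated over all posterior draws once sampling has finished, and that recomputation is compiled and run by JAX separately from the sampler itself:

```python
import numpy as np
import pymc as pm

y = np.random.normal(size=100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    # sigma is sampled on the log scale; its original-scale values are
    # recovered from the trace after sampling
    sigma = pm.HalfNormal("sigma", 1.0)
    # this Deterministic is also recomputed from the draws post-sampling
    pm.Deterministic("mu_sq", mu ** 2)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)
```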

I am also seeing this issue of long “Constant folding an instruction is taking > 1s” delays post-sampling.

pytensor                  2.25.2
pymc                      5.16.2
numpyro                   0.16.1

My model involves a scan routine, and nuts_sampler="numpyro" brings the sampling time down from ~30 min to ~20 s (awesome!). However, resolving the “constant folding” slow_operation_alarm then takes ~3 min.

Is the status still the same regarding possible ways to address this issue?
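For concreteness, the shape of what I am doing is roughly this (a toy stand-in with a scan-based mean, not my real model):

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

# Toy AR(1)-like data, just to stand in for my real model
rng = np.random.default_rng(0)
T = 200
y_obs = np.zeros(T)
for t in range(1, T):
    y_obs[t] = 0.8 * y_obs[t - 1] + rng.normal(scale=0.1)

with pm.Model() as model:
    rho = pm.Normal("rho", 0.0, 0.5)
    sigma = pm.HalfNormal("sigma", 0.5)

    # scan over the lagged observations to build the one-step-ahead mean
    mu, _ = pytensor.scan(
        fn=lambda y_prev, rho: rho * y_prev,
        sequences=[pt.as_tensor_variable(y_obs[:-1])],
        non_sequences=[rho],
    )
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs[1:])

    # sampling itself is fast with numpyro; the ~3 min wait comes afterwards,
    # while the constant-folding alarms are being printed
    idata = pm.sample(nuts_sampler="numpyro")
```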

Not much a user can do on their end, I’m afraid.

You could try nutpie, which doesn’t compute the deterministics in one batch after sampling but evaluates them one draw at a time while it is sampling. Maybe that doesn’t trigger the bad JAX constant folding?
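Roughly like this (a sketch; nutpie has to be installed separately):

```python
import pymc as pm

with model:  # e.g. the scan model from the post above
    # nutpie evaluates deterministics draw by draw during sampling, so there
    # is no big batched JAX recomputation step after the sampler finishes
    idata = pm.sample(nuts_sampler="nutpie")
```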

You can also control whether the postprocessing is done with vmap or scan: pymc/pymc/sampling/jax.py at main · pymc-devs/pymc · GitHub

If that’s where it’s happening.
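Something along these lines, assuming the keyword names in the linked jax.py for your version (I’m writing this from memory, so double-check the file; pm.sample forwards extra options through nuts_sampler_kwargs):

```python
import pymc as pm

with model:  # again, the scan model from above
    idata = pm.sample(
        nuts_sampler="numpyro",
        # ask the JAX postprocessing of deterministics/transforms to loop with
        # scan rather than vmap; keyword name taken from pymc/sampling/jax.py
        # and may differ between releases
        nuts_sampler_kwargs={"postprocessing_vectorize": "scan"},
    )
```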

I’m still not sure it makes sense to jit that postprocessing function when it is only called once.