Hi,
Context: I am building a Hill equation curve-fitting model. I am testing it with simulated data, using only 4 measurements to estimate the parameters. I am able to work with this by using a hierarchical model to group compounds together.
When sampling I get the NumPyro warning shown below. The sampling itself runs fast, but it then takes >30 s to get the inference data back. The trace and summary stats all look fine.
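For reference, here is a hypothetical sketch (plain NumPy, illustrative parameter values, not my actual test set) of the kind of simulated data I am fitting, using the same four-parameter Hill curve on the log10 dose scale as the model below:

```python
import numpy as np

def inv_logit(z):
    """Plain inverse logit, matching pm.invlogit."""
    return 1.0 / (1.0 + np.exp(-z))

def hill_response(log10_rate, ldr, hdr, shape, logED50):
    """Four-parameter Hill curve: ldr + (hdr - ldr) * invlogit(-shape * (x - logED50))."""
    return ldr + (hdr - ldr) * inv_logit(-shape * (log10_rate - logED50))

rng = np.random.default_rng(0)
log10_rate = np.array([1.0, 2.0, 3.0, 4.0])              # only 4 measurements per compound
true = dict(ldr=0.05, hdr=0.95, shape=1.5, logED50=2.5)  # illustrative "true" values
obs_response = hill_response(log10_rate, **true) + rng.normal(0.0, 0.05, size=4)
```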
**Model:**
```python
# Coordinates
coords_model_v2 = {"cpd_id": cpd_id,
                   "moa_id": moa_id}

# Model with pooled HDR, LDR, sd, shape parameters
with pm.Model(coords=coords_model_v2) as ED50_fit_v2:
    # Data definition
    rates = pm.Data('log10_rate', raw_data.rates, dims="obs_id")

    # Global prior on observation sd;
    # the model expects the same noise for all compounds across tests and test plans
    sd = pm.HalfNormal("sd", sigma=0.3)

    # MOA-specific priors
    # The model expects the upper dose response to be close to 1 but not equal to it
    # (compatibility with a beta response)
    hdr = pm.Beta("hdr", alpha=10, beta=1.0, dims="moa_id")
    # The model expects the lower dose response to be close to 0 but not equal to it
    # (compatibility with a beta response)
    ldr = pm.Beta("ldr", alpha=1.0, beta=10, dims="moa_id")
    # The model expects slopes from 0 to 8
    shape = pm.Gamma("shape", alpha=3.0, beta=1.0, dims="moa_id")

    # Compound-specific prior; the model expects pIC50 from 0 to 7
    logED50 = pm.Normal("logED50", mu=2.5, sigma=1, dims="cpd_id")

    # Linear predictor
    linear = shape[moa_idx] * (rates - logED50[cpd_idx])
    # Hill equation for the dose response
    # (previous implementation used "beta" = hdr - ldr and no ldr offset)
    mu = pm.Deterministic("Hill",
                          ldr[moa_idx] + (hdr[moa_idx] - ldr[moa_idx]) * pm.invlogit(-linear),
                          dims="obs_id")

    # Real observations with noise
    response = pm.Normal("response",
                         mu=mu,
                         sigma=sd,
                         observed=raw_data.obs_response,
                         dims="obs_id")

pm.model_to_graphviz(ED50_fit_v2)
```
**Sampling:**

```python
with ED50_fit_v2:
    trace_ED50_fit_v2 = pm.sample(tune=1000,
                                  draws=1000,
                                  chains=4,
                                  target_accept=0.90,
                                  random_seed=None,
                                  nuts_sampler='numpyro',
                                  return_inferencedata=True,
                                  idata_kwargs=dict(log_likelihood=False))
```
**NumPyro warning message:**

```
2024-09-12 19:14:20.503302: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 8s:
  %reduce = f64[4,1000,404]{2,1,0} reduce(f64[4,1000,1,404]{3,2,1,0} %broadcast.9, f64[] %constant.16), dimensions={2}, to_apply=%region_0.36, metadata={op_name="jit(process_fn)/jit(main)/reduce_prod[axes=(2,)]" source_file="/tmp/tmpvz9exf9w" source_line=17}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-09-12 19:14:36.278377: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 23.775190738s
Constant folding an instruction is taking > 8s:
  %reduce = f64[4,1000,404]{2,1,0} reduce(f64[4,1000,1,404]{3,2,1,0} %broadcast.9, f64[] %constant.16), dimensions={2}, to_apply=%region_0.36, metadata={op_name="jit(process_fn)/jit(main)/reduce_prod[axes=(2,)]" source_file="/tmp/tmpvz9exf9w" source_line=17}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
```
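One observation, in case it narrows things down: the `f64[4,1000,404]` shape in the log matches (chains, draws, observations), which I read as the post-sampling transform materialising the `Hill` Deterministic for every draw. A variant I am considering is dropping `pm.Deterministic` and reconstructing the curve from the posterior draws afterwards in plain NumPy. A hypothetical sketch, with fake draws standing in for `trace_ED50_fit_v2.posterior` (single MOA group and single compound for brevity):

```python
import numpy as np

def inv_logit(z):
    """Plain inverse logit, matching pm.invlogit."""
    return 1.0 / (1.0 + np.exp(-z))

# Fake posterior draws standing in for trace_ED50_fit_v2.posterior
# (chains=4, draws=1000), sampled from the priors for illustration only.
rng = np.random.default_rng(1)
shape_draws = rng.gamma(3.0, 1.0, size=(4, 1000))
logED50_draws = rng.normal(2.5, 1.0, size=(4, 1000))
ldr_draws = rng.beta(1.0, 10.0, size=(4, 1000))
hdr_draws = rng.beta(10.0, 1.0, size=(4, 1000))

log10_rate = np.array([1.0, 2.0, 3.0, 4.0])  # the 4 measurements

# Broadcast draws against observations: result is (chains, draws, obs)
linear = shape_draws[..., None] * (log10_rate - logED50_draws[..., None])
hill = ldr_draws[..., None] + (hdr_draws[..., None] - ldr_draws[..., None]) * inv_logit(-linear)
```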
What is the origin of this message, and should I modify the model to resolve it?
Thank you in advance for your feedback!