I’ve been using the experimental BlackJAX sampler via `pm.sampling_jax.sample_blackjax_nuts()` and want to avoid recompilation when the data changes. The example in the `pymc.set_data` documentation (PyMC 5.0.2) doesn’t seem to avoid recompilation with this sampler. Is there a way to make this work with JAX?

`pm.set_data` just sets the value of a shared variable. You can compile a logp function once and change the data afterwards without recompiling it:

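```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    data = pm.MutableData("data", np.ones(10))
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=data)

ip = m.initial_point()
logp = m.compile_logp()  # compiled once
print(logp(ip))

with m:
    # Only swaps the value of the underlying shared variable; logp is not recompiled
    pm.set_data({"data": np.ones(100)})
print(logp(ip))  # now evaluated with 100 observations
```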

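For JAX sampling, PyMC doesn’t compile a whole PyTensor logp function. Instead, it inplaces the shared variables, replacing them with their current values as constants before the graph is converted to JAX, so new data means a new function that has to be compiled again. We could easily provide the JAX function with the “shared variables” as explicit inputs of the logp instead of inplaced constants.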
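To illustrate the difference, here is a plain-JAX sketch; it is not PyMC’s API or its actual JAX backend. The names `normal_logp`, `make_baked_logp` and `explicit_logp` are hypothetical, and the logp is hand-written for the toy model `x ~ Normal(0, 1)`, `y ~ Normal(x, 1)`. A counter records how often each function is traced: the version with the data baked in as a constant is traced again for every new data value, while the version that takes the data as an argument is reused. One caveat: `jax.jit` specializes on input shapes and dtypes, so going from 10 to 100 observations, as in the `set_data` example, would still trigger a recompile even with explicit inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Counts how many times JAX traces (and therefore compiles) each function.
trace_counts = {"baked": 0, "explicit": 0}


def normal_logp(value, mu):
    # Log-density of Normal(mu, 1), summed over all elements of `value`.
    return jnp.sum(-0.5 * (value - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi))


def make_baked_logp(data):
    # Hypothetical "inplaced" version: the data is closed over, so it becomes
    # a compile-time constant of this particular jitted function.
    data = jnp.asarray(data)

    @jax.jit
    def logp(x):
        trace_counts["baked"] += 1  # Python side effect: runs only while tracing
        return normal_logp(x, 0.0) + normal_logp(data, x)

    return logp


@jax.jit
def explicit_logp(x, data):
    # Hypothetical "explicit input" version: data is an argument, so new values
    # reuse the compiled function as long as shape and dtype are unchanged.
    trace_counts["explicit"] += 1
    return normal_logp(x, 0.0) + normal_logp(data, x)


x = jnp.float32(0.0)
d1 = np.ones(10, dtype=np.float32)
d2 = np.full(10, 2.0, dtype=np.float32)
d3 = np.ones(100, dtype=np.float32)

for d in (d1, d2):
    make_baked_logp(d)(x)  # each new closure is traced and compiled again
for d in (d1, d2):
    explicit_logp(x, d)  # traced once, reused for the second data value
explicit_logp(x, d3)  # new shape (100,), so JAX traces again

print(trace_counts)  # {'baked': 2, 'explicit': 2}
```

In a real sampler the explicit-input function would be the one handed to the NUTS kernel, with the data passed in at each call rather than captured when the graph is built.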