Avoid recompilation when using BlackJax sampling


I’ve been using the experimental BlackJax sampler via `pm.sampling_jax.sample_blackjax_nuts()` and want to avoid recompilation when the data change. The example in the `pymc.set_data` documentation (PyMC 5.0.2) doesn't seem to avoid recompilation with this sampler. Is there a way to make this work with JAX?



Not an answer to your question, but I have had good experience using nutpie (GitHub: pymc-devs/nutpie, a Python wrapper for nuts-rs) to sample without recompilation,

using `with_data` (see nutpie/compile_pymc.py at commit e291c5b4f1346aefc60dfadefdf5e28bd9eba564 in pymc-devs/nutpie on GitHub).

For BlackJax, you would need to investigate where (if anywhere) in its API you can provide a jitted logp function, and then do the plumbing yourself.

Maybe @junpenglao can give some pointers

Does pm.set_data generate a new logprob function? If so, there would be quite a bit of custom code involved to avoid recompiling in JAX.
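Why recompilation happens in JAX at all can be seen without PyMC. A jitted function that closes over an array treats it as a compile-time constant, so every new dataset means a new function and a new trace; a function that takes the array as an argument is compiled once per input shape and dtype. This sketch counts traces with a Python side effect (`trace_count` and both function names are illustrative only). A shape change, such as going from 10 to 100 observations, still triggers a retrace:

```python
import jax
import jax.numpy as jnp

trace_count = {"closure": 0, "argument": 0}


def make_closure_logp(data):
    # `data` is captured from the enclosing scope, so JAX embeds it in the
    # compiled program as a constant; each new dataset needs a new function.
    @jax.jit
    def logp(mu):
        trace_count["closure"] += 1  # Python side effect: runs only while tracing
        return jnp.sum(-0.5 * (data - mu) ** 2)

    return logp


@jax.jit
def logp_with_data(mu, data):
    # `data` is a traced argument: new values with the same shape and dtype
    # reuse the cached compilation.
    trace_count["argument"] += 1
    return jnp.sum(-0.5 * (data - mu) ** 2)


datasets = [jnp.ones(10), jnp.zeros(10), 2.0 * jnp.ones(10)]
for d in datasets:
    make_closure_logp(d)(0.0)
    logp_with_data(0.0, d)

print(trace_count)  # {'closure': 3, 'argument': 1}

# A different shape (10 -> 100 observations) forces a retrace even here.
logp_with_data(0.0, jnp.ones(100))
print(trace_count)  # {'closure': 3, 'argument': 2}
```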

pm.set_data just sets the value of a shared variable. You can compile a logp function and change the data afterwards without recompiling it:

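```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    data = pm.MutableData("data", np.ones(10))
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=data)

ip = m.initial_point()
logp = m.compile_logp()  # compiled once
print(logp(ip))

with m:
    # Only updates the shared variable's value; `logp` is not recompiled
    pm.set_data({"data": np.ones(100)})

print(logp(ip))  # the same compiled function now sees 100 observations
```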

For JAX sampling via PyMC, we don’t actually compile a whole PyTensor logp function, and we replace the shared variables in place with constants. We could, however, easily provide a JAX function in which the “shared variables” are explicit inputs of the logp instead of inlined constants.
