Checkpoint / Resuming sampling with NUTS sampler

Hi all,

I want to use the NUTS sampler to sample a model, but with some form of check-pointing. I can successfully sample from the model and obtain the trace results, so sampling itself is not the issue.

I am working on a server, but the maximum allotted job time is not enough to complete the sampling in a single run, so I need to split the sampling across multiple jobs, each continuing where the previous one left off, until I have sufficient samples.

Therefore, I want to be able to:

  1. Tune for, say, 1000 steps, take the first draws, and save the trace result
  2. Use that trace to continue sampling for 1000 more draws and save a second trace result
  3. Use the second trace to draw 1000 more samples and save that trace result, and so on for as many draws as I need (roughly the pattern sketched below)
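In other words, each server job would follow roughly this pattern. This is a toy stand-in for the real sampler, purely to illustrate the job structure I am after (the file name and the random-walk "sampler" are hypothetical):

import os
import pickle
import numpy as np

CKPT = "checkpoint.pkl"  # hypothetical checkpoint file

def draw_samples(start, n=1000):
    # stand-in for real NUTS sampling: a random walk from the last position
    steps = np.random.normal(size=(n, start.size))
    return start + np.cumsum(steps, axis=0)

if os.path.exists(CKPT):
    with open(CKPT, "rb") as f:
        start = pickle.load(f)  # later jobs: resume from the saved position
else:
    start = np.zeros(3)         # first job: fresh start (after tuning)

trace = draw_samples(start)
# ... append `trace` to the saved draws, then store the last position
with open(CKPT, "wb") as f:
    pickle.dump(trace[-1], f)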

What would be the best approach for this? Is this feasible in PyMC alone? Any help would be appreciated.

As far as I can see, the most recent question on this topic was posted here, and it still has not been resolved.

Hey – this is a really reasonable ask (though your particular situation sounds like a frustrating experience!). @cluhmann and I were just talking about this sort of “warm start” checkpointing. If you save (for each chain!)

  • The mass matrix
  • The step size
  • The last position

you should be able to resume sampling.
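Concretely, a per-chain checkpoint does not need to be anything fancier than a pickled dict. Here is a minimal sketch, with illustrative dummy values standing in for the quantities you would pull out of your sampler:

import pickle
import numpy as np

# dummy stand-ins; in practice these come out of warmup/sampling
inverse_mass_matrix = np.ones(3)  # diagonal mass matrix for a 3-parameter model
step_size = 0.5
last_position = np.zeros(3)       # final draw of this chain

checkpoint = {
    "inverse_mass_matrix": inverse_mass_matrix,
    "step_size": step_size,
    "last_position": last_position,
}
with open("chain_0_checkpoint.pkl", "wb") as f:
    pickle.dump(checkpoint, f)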

I work more in JAX now, and Blackjax is probably what I would reach for to implement this initially.

Specifically, something like https://blackjax-devs.github.io/sampling-book/models/change_of_variable_hmc.html:

import jax
import blackjax

# joint_logdensity, init_param_fn, and rng_key are defined earlier on the linked page
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)

# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)

# run 1000 window-adaptation steps independently for each chain
@jax.vmap
def call_warmup(seed, param):
    (initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_states, tuned_params

warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)

The tuned_params will contain the mass matrix and step size for each chain.
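If you want to sanity-check what adaptation produced: after the vmap, tuned_params is a dict of arrays with a leading chain dimension. The key names below are what recent Blackjax versions use, so treat them as an assumption:

print(tuned_params["step_size"].shape)            # (n_chains,)
print(tuned_params["inverse_mass_matrix"].shape)  # (n_chains, n_dims)

Then (from the same linked page):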

n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)

# inference_loop_multiple_chains is defined on the linked page; it repeatedly
# applies the NUTS kernel with the tuned (now fixed) parameters across chains
states, infos = inference_loop_multiple_chains(
    sample_key, initial_states, tuned_params, joint_logdensity, n_samples, n_chains
)

You would want to save the most recent states that were sampled (using something like jax.tree.map(lambda x: x[-1], states)), but note that tuned_params no longer changes after warmup, so you only need to save it once.
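Putting it together, the save/restore cycle between server jobs might look like the sketch below. It assumes the variables from the snippets above and converts everything to NumPy before pickling (JAX arrays also pickle, but NumPy keeps the checkpoint device-independent); the file name is illustrative:

import pickle
import numpy as np

# end of job: keep only the final state of each chain as the warm start,
# together with the (now fixed) tuned parameters
last_states = jax.tree.map(lambda x: np.asarray(x[-1]), states)
checkpoint = {
    "last_states": last_states,
    "tuned_params": jax.tree.map(np.asarray, tuned_params),
}
with open("nuts_checkpoint.pkl", "wb") as f:
    pickle.dump(checkpoint, f)

# start of next job: reload and continue sampling, skipping warmup entirely
with open("nuts_checkpoint.pkl", "rb") as f:
    checkpoint = pickle.load(f)

sample_key = jax.random.key(0)  # ideally save or derive this too, for reproducibility
states, infos = inference_loop_multiple_chains(
    sample_key,
    checkpoint["last_states"],
    checkpoint["tuned_params"],
    joint_logdensity,
    n_samples,
    n_chains,
)

Repeating that reload-and-continue step once per job gives you exactly the chain-of-traces workflow you described, with tuning paid for only once.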
