Essentially, this comes down to the PyMC samplers being stateful, and there currently being no standardized API to save and restore that state.
That state includes instance attributes, but also random states (see pymc-devs/pymc#5797, "Refactor step methods to use their own random stream").
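To make the missing API concrete, here is a minimal sketch of what a save/restore interface could look like. The `StatefulStep` protocol, its `get_state`/`set_state` methods, the `checkpoint` helper and the `ToyStep` class are hypothetical names for illustration, not existing PyMC APIs. The only real library feature relied upon is NumPy's `Generator.bit_generator.state`, which can be read and assigned as a plain dict.

```python
import copy
from typing import Any, Protocol

import numpy as np


class StatefulStep(Protocol):
    """Hypothetical save/restore interface; not part of PyMC today."""

    def get_state(self) -> dict[str, Any]: ...

    def set_state(self, state: dict[str, Any]) -> None: ...


class ToyStep:
    """Toy step method with both kinds of state: attributes and its own RNG."""

    def __init__(self, seed=None):
        self.step_size = 0.5                    # instance attribute changed by "tuning"
        self.rng = np.random.default_rng(seed)  # own random stream, as proposed in #5797

    def draw(self) -> float:
        self.step_size *= 1.01                  # stand-in for adaptation
        return self.step_size * float(self.rng.normal())

    def get_state(self) -> dict[str, Any]:
        return {
            "step_size": self.step_size,
            "rng_state": copy.deepcopy(self.rng.bit_generator.state),
        }

    def set_state(self, state: dict[str, Any]) -> None:
        self.step_size = state["step_size"]
        self.rng.bit_generator.state = copy.deepcopy(state["rng_state"])


def checkpoint(step: StatefulStep) -> dict[str, Any]:
    # Works for any object implementing the protocol (structural typing).
    return step.get_state()


original = ToyStep(seed=1)
for _ in range(10):
    original.draw()
saved = checkpoint(original)
expected = [original.draw() for _ in range(5)]

restored = ToyStep(seed=999)  # deliberately different seed
restored.set_state(saved)
resumed = [restored.draw() for _ in range(5)]
# expected and resumed are identical
```

Restoring only the attribute or only the random state would not be enough: the continuation matches only when both are captured.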
I’m currently refactoring the trace backend to be natively McBackend-compatible, aiming to eventually delete `pm.backends.BaseTrace` and `pm.backends.ndarray.NDArray` in favor of defaulting to `mcbackend.NumPyBackend`, which supports sparse sampler stats (see pymc-devs/pymc#6194, "Support sparse sample stats").
With sparse sampler stats, we can start saving the sampler state as sampler stats: for example, by storing the NUTS mass matrix as a stat every time it changes during tuning, or by emitting the current random state every 100 iterations or so.
Having such “keyframes” in the trace could then be the starting point for properly restoring stateful samplers and resuming an MCMC run.
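The keyframe idea can be sketched with a toy random-walk Metropolis sampler in plain NumPy. Everything here is an assumption for illustration: the `sample`/`resume` helpers, the stat names, and representing the trace as one dict per draw in which sparse stats are simply missing keys (McBackend's actual storage differs). The scale is recorded only when tuning changes it, and every 100 draws a keyframe stores all remaining state, including the tuning counter and the random state, so the chain can be continued from the last keyframe and match the uninterrupted run.

```python
import copy

import numpy as np


def logp(x):
    # Toy target: standard normal.
    return -0.5 * float(np.sum(x**2))


def sample(x, scale, rng, start, stop, n_accepted=0, n_tune=200, window=50, keyframe_every=100):
    """Hypothetical helper: random-walk Metropolis over draws start..stop-1.

    Returns the draws and one stats dict per draw; sparse stats are keys
    that are absent from most dicts.
    """
    draws, stats = [], []
    for i in range(start, stop):
        proposal = x + scale * rng.normal(size=x.shape)
        accepted = bool(np.log(rng.uniform()) < logp(proposal) - logp(x))
        if accepted:
            x = proposal
        n_accepted += accepted
        s = {"accepted": accepted}  # dense stat, emitted every iteration
        if i < n_tune and (i + 1) % window == 0:
            # Tuning: adapt the scale and record it only when it changes.
            scale *= 1.5 if n_accepted / window > 0.4 else 0.7
            n_accepted = 0
            s["scale"] = scale
        if (i + 1) % keyframe_every == 0:
            # Keyframe: all sampler state needed to resume after draw i.
            s["keyframe"] = {
                "scale": scale,
                "n_accepted": n_accepted,
                "rng_state": copy.deepcopy(rng.bit_generator.state),
            }
        draws.append(x.copy())
        stats.append(s)
    return draws, stats


def resume(draws, stats, stop):
    """Hypothetical helper: restart from the last keyframe and sample up to `stop`."""
    k = max(i for i, s in enumerate(stats) if "keyframe" in s)
    key = stats[k]["keyframe"]
    rng = np.random.default_rng()  # seed is irrelevant, the state is overwritten
    rng.bit_generator.state = key["rng_state"]
    return k, sample(draws[k], key["scale"], rng, k + 1, stop, n_accepted=key["n_accepted"])


# Pretend this run was interrupted after 350 draws.
draws, stats = sample(np.zeros(2), 1.0, np.random.default_rng(2023), 0, 350)
k, (new_draws, new_stats) = resume(draws, stats, stop=500)
combined = np.array(draws[: k + 1] + new_draws)

# Reference: the same chain run without interruption.
full_draws, _ = sample(np.zeros(2), 1.0, np.random.default_rng(2023), 0, 500)
# combined and np.array(full_draws) are identical
```

Draws after the last keyframe (here 300 to 349) are discarded and recomputed; emitting keyframes more often trades storage for less recomputation.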
With pymc-devs/pymc#6475 ("Refactoring towards `IBaseTrace` interfaces" by michaelosthege), I’m actually getting pretty close to optional McBackend support already. Any help with refactoring the step method interface (e.g. including shape information in `.stats_dtypes`, or taking care of pymc-devs/pymc#5797, "Refactor step methods to use their own random stream") would be greatly appreciated!