Resuming sampling from a previous trace

Hi. Is there a way to resume sampling from a previous trace / inference data (or other state from a previous sampling run)? I’d like to be able to draw more samples, without having to draw any tuning samples.

I’m using the NUTS sampler and I’ve tried the approaches @junpenglao outlines in this discussion, but both fail for me:

  • The new model approach fails because models no longer have a bijection attribute
  • The same model approach draws identical samples that do not change

Is there a way to resume sampling from a previous trace / inference data? Ideally, I’m looking for a solution that works with any sampling method, not just NUTS. If not, is this something PyMC is planning to support? It seems like a fairly common use case to draw some samples, check the summary and convergence statistics, and then draw more samples.

Thanks,
Laurence

Welcome!

You might want to check out related discussion here.

Hi @cluhmann. Thanks for the link. I’ve tried the solutions outlined in that discussion, but they didn’t work:

  • The iter_sample function no longer exists
  • If I use mcbackend to draw samples, then use it to draw samples without tuning, every sample in the second set of draws is identical

Interesting. Maybe @michaelosthege has some suggestions?

Essentially this is a matter of the PyMC samplers being stateful and there currently not being a standardized API to save/restore that state.
This includes things like instance attributes, but also random states (see Refactor step methods to use their own random stream · Issue #5797 · pymc-devs/pymc · GitHub).
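As a minimal illustration of what saving a random state involves (a NumPy-only sketch, not how PyMC step methods currently manage their random streams): a `numpy.random.Generator` exposes its full state as a plain dict via `rng.bit_generator.state`, and assigning that dict to a fresh generator makes it continue exactly where the original left off. The helper `draw_block` is hypothetical and only stands in for sampler iterations that consume random numbers.

```python
import numpy as np


def draw_block(rng, n):
    """Hypothetical stand-in for n sampler iterations that consume randomness."""
    return rng.normal(size=n)


rng = np.random.default_rng(42)
first = draw_block(rng, 5)

# Snapshot the generator state (a plain, picklable dict) at the end of "run 1".
saved_state = rng.bit_generator.state

# Reference: what an uninterrupted run would have drawn next.
continued = draw_block(rng, 5)

# "Resume" later: create a new generator and overwrite its state.
resumed_rng = np.random.default_rng()
resumed_rng.bit_generator.state = saved_state
resumed = draw_block(resumed_rng, 5)

print(np.array_equal(continued, resumed))  # True
```

Without the restored state, the resumed generator would produce an unrelated stream, which is why the RNG state has to be part of whatever a sampler saves.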

I’m currently refactoring the trace backend to be natively McBackend-compatible, aiming to eventually delete pm.backends.BaseTrace and pm.backends.ndarray.NDArray in favor of defaulting to mcbackend.NumPyBackend, which supports sparse sampler stats (Support sparse sample stats · Issue #6194 · pymc-devs/pymc · GitHub).

With sparse sampler stats we can start saving the sampler state in sampler stats, for example by storing the mass matrix information in NUTS as a sampler stat every time it changed during tuning, or by emitting the current random state as a sampler stat every 100 iterations or so.
Having such “keyframes” in the trace could then be the starting point for properly restoring stateful samplers and resuming an MCMC.
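A toy sketch of the keyframe idea, assuming a hand-written random-walk Metropolis sampler rather than any PyMC step method: every `keyframe_every` iterations it records the current position, its proposal scale (fixed here, tuned in a real sampler) and the RNG state as a sparse stat. Resuming from the latest keyframe then reproduces the original chain draw for draw. The function name, the `keyframes` dict layout and the interval of 100 are assumptions for illustration only.

```python
import numpy as np


def metropolis(logp, x0, n, rng, scale=1.0, keyframe_every=100):
    """Toy random-walk Metropolis that emits a sparse "keyframe" stat.

    Returns (draws, keyframes); keyframes maps a draw index to the state
    recorded right after that draw was produced.
    """
    x = x0
    draws = np.empty(n)
    keyframes = {}
    for i in range(n):
        proposal = x + scale * rng.normal()
        if np.log(rng.uniform()) < logp(proposal) - logp(x):
            x = proposal
        draws[i] = x
        if (i + 1) % keyframe_every == 0:
            keyframes[i] = {"x": x, "scale": scale, "rng_state": rng.bit_generator.state}
    return draws, keyframes


def logp(x):
    return -0.5 * x**2  # standard normal, up to a constant


draws, keyframes = metropolis(logp, 0.0, 1000, np.random.default_rng(1))

# Resume from the latest keyframe at or before draw 450 (that is draw 399).
k = max(i for i in keyframes if i <= 450)
rng = np.random.default_rng()
rng.bit_generator.state = keyframes[k]["rng_state"]
tail, _ = metropolis(logp, keyframes[k]["x"], 1000 - (k + 1), rng, scale=keyframes[k]["scale"])

print(k, np.array_equal(tail, draws[k + 1:]))  # 399 True
```

Because the keyframe is only emitted at some iterations, this is the kind of data that needs sparse sampler-stat support in the trace backend.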

With Refactoring towards `IBaseTrace` interfaces by michaelosthege · Pull Request #6475 · pymc-devs/pymc · GitHub I’m actually getting pretty close to optional McBackend support already. Any help with refactoring the step method interface (e.g. including shape information in .stats_dtypes, or taking care of Refactor step methods to use their own random stream · Issue #5797 · pymc-devs/pymc · GitHub) would be greatly appreciated!


Thank you for your explanation. As the PyMC package is now, is there a way to resume sampling from a previous trace?

Thanks.


Thank you for the great explanation and for the work you are doing on this issue. Since the solution depends on several issues and PRs, it’s hard for non-developers like me to follow its state. Could you post in this thread once a PR (or a combination of PRs) provides restarting functionality?

Kind regards, Hans Ekbrand

Optional McBackend support was released in v5.1.1, so the next steps are:

  • Start testing McBackend with your models & contribute feedback and bugfixes!
  • Modify simple step methods to periodically emit state (random number generator state, tuned hyperparameters) as sampler stats
  • Add optional kwargs to the step method to restore them, like restored_step = pm.Metropolis(..., **idata.sample_stats.sel(draw=12_000))
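The third step could look roughly like this for a simple sampler. `ToyMetropolis` is a hypothetical class, not `pm.Metropolis`, and `stats[11]` plays the role of `idata.sample_stats.sel(draw=...)`: every stat it emits has the same name as a constructor keyword, so `ToyMetropolis(logp, **stats[11])` rebuilds a step that continues the chain exactly after draw 11.

```python
import numpy as np


class ToyMetropolis:
    """Hypothetical step method (NOT pm.Metropolis) whose stats double as kwargs."""

    def __init__(self, logp, scaling=1.0, rng_state=None, seed=None):
        self.logp = logp
        self.scaling = scaling
        self.rng = np.random.default_rng(seed)
        if rng_state is not None:
            # Restore the exact position in the random stream.
            self.rng.bit_generator.state = rng_state

    def step(self, x):
        proposal = x + self.scaling * self.rng.normal()
        accept = np.log(self.rng.uniform()) < self.logp(proposal) - self.logp(x)
        new_x = proposal if accept else x
        # Emit the state under the same names the constructor accepts.
        stats = {"scaling": self.scaling, "rng_state": self.rng.bit_generator.state}
        return new_x, stats


def logp(x):
    return -0.5 * x**2


def run(step, x, n):
    draws, stats = [], []
    for _ in range(n):
        x, s = step.step(x)
        draws.append(x)
        stats.append(s)
    return draws, stats


original = ToyMetropolis(logp, scaling=0.8, seed=3)
draws, stats = run(original, 0.0, 20)

# Rebuild the step from the stats emitted at draw 11 and continue from there.
restored = ToyMetropolis(logp, **stats[11])
tail, _ = run(restored, draws[11], 8)
print(tail == draws[12:])  # True
```

Keeping stat names and constructor kwargs identical is what makes the `**stats` round trip work; in a real step method `scaling` would be a tuned value rather than a constant.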

For simplicity this should first be done for a simple non-HMC sampler, because such a sampler is much easier to dig into and test.

AFAIK this is currently not a priority for any of the core developers, so I would like to encourage you to get involved!

Has there been any progress on this since this post? Can this be used more reliably with HMC? Thank you!

@juststarted I didn’t have the time (I have other priorities), but I am not aware of problems either. The test_mcbackend.py suite runs with every CI in the PyMC project, for example.

Thanks for the answer! I’ll test it on my model with NUTS and see if it works. Do you know if there’s documentation somewhere showing how to store a run with PyMC and restart it from where it left off? I couldn’t find anything on the forum, and the mcbackend documentation seems to be library-neutral.

Okay, after some tests here is what I am running:

import arviz
import numpy as np
import pymc as pm
import mcbackend as mcb
from clickhouse_driver import Client


def define_simple_model():
    seconds = np.linspace(0, 5)
    observations = np.random.normal(0.5 + np.random.uniform(size=3)[:, None] * seconds[None, :])
    with pm.Model(
        coords={
            "condition": ["A", "B", "C"],
        }
    ) as pmodel:
        x = pm.ConstantData("seconds", seconds, dims="time")
        a = pm.Normal("scalar")
        b = pm.Uniform("vector", dims="condition")
        pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time"))
        obs = pm.MutableData("obs", observations, dims=("condition", "time"))
        pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))
        
    return pmodel


if __name__=='__main__':
    
    simple_model = define_simple_model()
    backend = mcb.NumPyBackend()
    with simple_model:
        trace = pm.sample(
            trace=backend
        )
        
    print(trace)

This raises the following error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [scalar, vector]
100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Traceback (most recent call last):
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 35, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 702, in sample
    return _sample_return(
           ^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 736, in _sample_return
    mtrace = MultiTrace(traces)[:length]
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 370, in __getitem__
    return self._slice(idx)
           ^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in _slice
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in <listcomp>
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 194, in _slice
    stats = self._chain.get_stats_at(i, stat_names=snames)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in get_stats_at
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in <dictcomp>
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
                              ~~~~~~~~~~~~~~~^^^^^
IndexError: index 1000 is out of bounds for axis 0 with size 1000

Ultimately, I want to use the ClickHouse backend, but when I replace the relevant part with the following:

if __name__=='__main__':
    
    simple_model = define_simple_model()
    
    ch_client = Client("localhost")
    # Check that it is defined properly
    print(ch_client.execute('SHOW DATABASES'))
    backend = mcb.ClickHouseBackend(ch_client)
    
    with simple_model:
        trace = pm.sample(
            trace=backend
        )
        
    print(trace)

another error is raised:

[('INFORMATION_SCHEMA',), ('default',), ('information_schema',), ('system',)]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Traceback (most recent call last):
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 37, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 623, in sample
    run, traces = init_traces(
                  ^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/__init__.py", line 127, in init_traces
    return init_chain_adapters(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 278, in init_chain_adapters
    adapters = [
               ^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 280, in <listcomp>
    chain=run.init_chain(chain_number=chain_number),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 325, in init_chain
    create_chain_table(self._client, cmeta, self.meta)
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 98, in create_chain_table
    columns.append(column_spec_for(var, is_stat=True))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 65, in column_spec_for
    raise KeyError(
KeyError: "Don't know how to store dtype object of 'sampler_0__warning' (is_stat=True) in ClickHouse."

I am using PyMC v5.3.0. Any help appreciated!

EDIT:

I fixed the error by adding the following lines in “mcbackend/backends/clickhouse.py” on line 101:

if var.dtype=='object':
    var.dtype='str'

As it turns out, this only affects the stat “sampler_0__warning”:

Variable(name='tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__depth', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__mean_tree_accept', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size_bar', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tree_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__diverging', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__max_energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__model_logp', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__process_time_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_start', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__largest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__smallest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__index_in_trajectory', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__reached_max_treedepth', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__warning', dtype='object', shape=[], dims=[], is_deterministic=False, undefined_ndim=True)
False 
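The local patch effectively stores object-dtype stats as strings. A sketch of that fallback as an explicit, opt-in lookup instead of mutating `var.dtype`; `column_type_for` and its mapping are hypothetical (not mcbackend’s `column_spec_for`), and only the standard ClickHouse type names `Float64`, `Float32`, `Int64`, `Int32`, `Bool` and `String` are assumed:

```python
# Hypothetical dtype -> ClickHouse column type mapping (not mcbackend's own).
CLICKHOUSE_TYPES = {
    "float64": "Float64",
    "float32": "Float32",
    "int64": "Int64",
    "int32": "Int32",
    "bool": "Bool",
    "str": "String",
}


def column_type_for(name, dtype, object_as_string=False):
    """Return a column type for a stat, raising for object dtype unless opted in."""
    if dtype == "object":
        if not object_as_string:
            raise KeyError(f"Don't know how to store dtype object of '{name}'.")
        # Free-form objects (e.g. sampler warnings) get a String column;
        # the values themselves would still need str() conversion before insert.
        dtype = "str"
    return CLICKHOUSE_TYPES[dtype]


stats = {
    "sampler_0__depth": "int64",
    "sampler_0__step_size": "float64",
    "sampler_0__diverging": "bool",
    "sampler_0__warning": "object",
}

# Which stats would trip the unpatched backend?
print([n for n, d in stats.items() if d == "object"])  # ['sampler_0__warning']
print({n: column_type_for(n, d, object_as_string=True) for n, d in stats.items()})
```

Making the fallback a flag keeps the original error for unexpected object-dtype stats while still allowing known free-form ones such as warnings.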

I ran the code above a few times with the ClickHouse backend (with the fix to mcbackend). Then I ran this code:

import mcbackend
from clickhouse_driver import Client

ch_client = Client("localhost")
backend = mcbackend.ClickHouseBackend(ch_client)

print("list of runs:\n", backend.get_runs())

# Fetch a single run from the database (downloads just metadata)
run = backend.get_run("6F6PW")

# Convert everything to `InferenceData`
idata = run.to_inferencedata()

print(idata)

Which, as expected, prints out the following:

list of runs:
                             created_at   
rid                                      
99PHY 2023-05-19 02:54:46.005139+00:00  \
TREHW 2023-05-19 02:56:15.987728+00:00   
1PQK6 2023-05-19 02:56:56.700751+00:00   
YKT3N 2023-05-19 03:34:37.115745+00:00   
WLKW4 2023-05-19 03:43:04.277188+00:00   
9HHAC 2023-05-19 03:49:53.529272+00:00   
1P93X 2023-05-19 03:51:50.767925+00:00   
Q8LLA 2023-05-19 03:54:46.947725+00:00   
DPP94 2023-05-19 03:56:24.189411+00:00   
9CFX4 2023-05-19 03:58:19.232365+00:00   
C3AZB 2023-05-19 04:01:53.742057+00:00   
A36Y7 2023-05-19 04:05:15.148334+00:00   
M4LND 2023-05-19 04:06:23.221856+00:00   
6F6PW 2023-05-19 04:44:46.898726+00:00   

                                                   proto  
rid                                                       
99PHY  RunMeta(rid='99PHY', variables=[Variable(name=...  
TREHW  RunMeta(rid='TREHW', variables=[Variable(name=...  
1PQK6  RunMeta(rid='1PQK6', variables=[Variable(name=...  
YKT3N  RunMeta(rid='YKT3N', variables=[Variable(name=...  
WLKW4  RunMeta(rid='WLKW4', variables=[Variable(name=...  
9HHAC  RunMeta(rid='9HHAC', variables=[Variable(name=...  
1P93X  RunMeta(rid='1P93X', variables=[Variable(name=...  
Q8LLA  RunMeta(rid='Q8LLA', variables=[Variable(name=...  
DPP94  RunMeta(rid='DPP94', variables=[Variable(name=...  
9CFX4  RunMeta(rid='9CFX4', variables=[Variable(name=...  
C3AZB  RunMeta(rid='C3AZB', variables=[Variable(name=...  
A36Y7  RunMeta(rid='A36Y7', variables=[Variable(name=...  
M4LND  RunMeta(rid='M4LND', variables=[Variable(name=...  
6F6PW  RunMeta(rid='6F6PW', variables=[Variable(name=...  
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
	> constant_data

Warmup iterations saved (warmup_*).
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(

This still has a problem: the “warmup_sample_stats” group of the idata contains a bunch of fields, but they are all empty.

@juststarted hey, sorry for not replying (I didn’t check discourse in a while).
In the meantime (PyMC v5.7.0) there was a bugfix that should apply to at least one of the problems you described.

If you run into other problems, please open an issue in GitHub - pymc-devs/mcbackend: A backend for storing MCMC draws. and tag me, so I won’t overlook it.