Resuming sampling from a previous trace

Okay, after some testing, here is what I am running:

import arviz
import numpy as np
import pymc as pm
import mcbackend as mcb
from clickhouse_driver import Client


def define_simple_model(seed=None):
    """Build a toy PyMC model with a per-condition linear trend over time.

    Observations are simulated as ``Normal(0.5 + slope_c * seconds, 1)`` for
    the three conditions ``A``, ``B``, ``C``, with slopes drawn from
    ``Uniform(0, 1)``. The model fits
    ``L ~ Normal(scalar + vector[condition] * seconds, 1)`` to them.

    Parameters
    ----------
    seed : int, optional
        Seed for simulating the observations. When ``None`` (the default) the
        legacy global ``np.random`` state is used, exactly as before, so a
        prior ``np.random.seed(...)`` call still reproduces the same data.
        Passing a seed uses an independent ``np.random.Generator`` instead.

    Returns
    -------
    pm.Model
        The model; its ``"matrix"`` deterministic holds the expected values
        with dims ``("condition", "time")``.
    """
    conditions = ["A", "B", "C"]
    seconds = np.linspace(0, 5)
    # Keep the global legacy RNG as the default for backward compatibility;
    # an explicit seed opts into a reproducible, isolated Generator.
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Draw the slopes before the noise so the random stream is consumed in the
    # same order as the original single expression did.
    slopes = rng.uniform(size=len(conditions))
    observations = rng.normal(0.5 + slopes[:, None] * seconds[None, :])
    with pm.Model(coords={"condition": conditions}) as pmodel:
        x = pm.ConstantData("seconds", seconds, dims="time")
        a = pm.Normal("scalar")
        b = pm.Uniform("vector", dims="condition")
        pm.Deterministic(
            "matrix", a + b[:, None] * x[None, :], dims=("condition", "time")
        )
        obs = pm.MutableData("obs", observations, dims=("condition", "time"))
        # Look up the deterministic by name to use it as the likelihood mean.
        pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))
    return pmodel


if __name__ == "__main__":
    # Build the toy model first, then an in-memory McBackend store to record
    # the draws into, and hand that store to PyMC's sampler.
    toy_model = define_simple_model()
    numpy_store = mcb.NumPyBackend()

    with toy_model:
        sampled = pm.sample(trace=numpy_store)

    print(sampled)

This raises the following error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [scalar, vector]
Traceback (most recent call last):████████████████████████████████████████████████████████████| 100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 35, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 702, in sample
    return _sample_return(
           ^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 736, in _sample_return
    mtrace = MultiTrace(traces)[:length]
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 370, in __getitem__
    return self._slice(idx)
           ^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in _slice
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in <listcomp>
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 194, in _slice
    stats = self._chain.get_stats_at(i, stat_names=snames)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in get_stats_at
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in <dictcomp>
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
                              ~~~~~~~~~~~~~~~^^^^^
IndexError: index 1000 is out of bounds for axis 0 with size 1000

Ultimately, I want to run with the ClickHouse backend, but when I replace the relevant part with the following:

if __name__ == "__main__":
    simple_model = define_simple_model()

    ch_client = Client("localhost")
    try:
        # Sanity check that the connection works: list the server's databases.
        print(ch_client.execute("SHOW DATABASES"))
        backend = mcb.ClickHouseBackend(ch_client)

        with simple_model:
            trace = pm.sample(
                trace=backend
            )

        print(trace)
    finally:
        # Always release the ClickHouse connection, even if sampling fails.
        ch_client.disconnect()

another error is raised:

[('INFORMATION_SCHEMA',), ('default',), ('information_schema',), ('system',)]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Traceback (most recent call last):
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 37, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 623, in sample
    run, traces = init_traces(
                  ^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/__init__.py", line 127, in init_traces
    return init_chain_adapters(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 278, in init_chain_adapters
    adapters = [
               ^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 280, in <listcomp>
    chain=run.init_chain(chain_number=chain_number),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 325, in init_chain
    create_chain_table(self._client, cmeta, self.meta)
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 98, in create_chain_table
    columns.append(column_spec_for(var, is_stat=True))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 65, in column_spec_for
    raise KeyError(
KeyError: "Don't know how to store dtype object of 'sampler_0__warning' (is_stat=True) in ClickHouse."

I am using PyMC v5.3.0. Any help is appreciated!

EDIT:

I fixed the error by adding the following lines to `mcbackend/backends/clickhouse.py` at line 101:

if var.dtype == "object":
    # mcbackend's ClickHouse backend does not know how to store dtype "object"
    # (e.g. the `sampler_0__warning` stat), so record such stats as strings.
    var.dtype = "str"

As it turns out, this only affects the `sampler_0__warning` stat, which is the only one with dtype `object`:

Variable(name='tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__depth', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__mean_tree_accept', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size_bar', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tree_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__diverging', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__max_energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__model_logp', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__process_time_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_start', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__largest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__smallest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__index_in_trajectory', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__reached_max_treedepth', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__warning', dtype='object', shape=[], dims=[], is_deterministic=False, undefined_ndim=True)
False 

I ran the code above a few times with the ClickHouse backend (with the fix to mcbackend). Then I ran this code:

ch_client = Client("localhost")
try:
    # The package is imported as `mcb` (`import mcbackend as mcb`); the bare
    # name `mcbackend` is not bound in this file.
    backend = mcb.ClickHouseBackend(ch_client)

    # Table of stored runs, indexed by run id (rid).
    print("list of runs:\n", backend.get_runs())

    # Fetch a single run from the database (downloads just metadata)
    run = backend.get_run("6F6PW")

    # Convert everything to `InferenceData`
    idata = run.to_inferencedata()

    print(idata)
finally:
    # Release the ClickHouse connection once the data has been retrieved.
    ch_client.disconnect()

As expected, this prints the following:

list of runs:
                             created_at   
rid                                      
99PHY 2023-05-19 02:54:46.005139+00:00  \
TREHW 2023-05-19 02:56:15.987728+00:00   
1PQK6 2023-05-19 02:56:56.700751+00:00   
YKT3N 2023-05-19 03:34:37.115745+00:00   
WLKW4 2023-05-19 03:43:04.277188+00:00   
9HHAC 2023-05-19 03:49:53.529272+00:00   
1P93X 2023-05-19 03:51:50.767925+00:00   
Q8LLA 2023-05-19 03:54:46.947725+00:00   
DPP94 2023-05-19 03:56:24.189411+00:00   
9CFX4 2023-05-19 03:58:19.232365+00:00   
C3AZB 2023-05-19 04:01:53.742057+00:00   
A36Y7 2023-05-19 04:05:15.148334+00:00   
M4LND 2023-05-19 04:06:23.221856+00:00   
6F6PW 2023-05-19 04:44:46.898726+00:00   

                                                   proto  
rid                                                       
99PHY  RunMeta(rid='99PHY', variables=[Variable(name=...  
TREHW  RunMeta(rid='TREHW', variables=[Variable(name=...  
1PQK6  RunMeta(rid='1PQK6', variables=[Variable(name=...  
YKT3N  RunMeta(rid='YKT3N', variables=[Variable(name=...  
WLKW4  RunMeta(rid='WLKW4', variables=[Variable(name=...  
9HHAC  RunMeta(rid='9HHAC', variables=[Variable(name=...  
1P93X  RunMeta(rid='1P93X', variables=[Variable(name=...  
Q8LLA  RunMeta(rid='Q8LLA', variables=[Variable(name=...  
DPP94  RunMeta(rid='DPP94', variables=[Variable(name=...  
9CFX4  RunMeta(rid='9CFX4', variables=[Variable(name=...  
C3AZB  RunMeta(rid='C3AZB', variables=[Variable(name=...  
A36Y7  RunMeta(rid='A36Y7', variables=[Variable(name=...  
M4LND  RunMeta(rid='M4LND', variables=[Variable(name=...  
6F6PW  RunMeta(rid='6F6PW', variables=[Variable(name=...  
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
	> constant_data

Warmup iterations saved (warmup_*).
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(

One problem remains: the `warmup_sample_stats` group of the InferenceData contains many fields, but they are all empty. This matches the ArviZ warnings above, which report 0 draws.