Problem resuming the NUTS sampler

Hi!

I am trying to run a NUTS sampler. I want to save the model at some point and resume the sampler later to draw more samples, so I use pm.save_trace and pm.load_trace. The first run and the save/load step work fine. However, when I call pm.sample() with trace=load_trace to resume, I get this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_63722/2393317151.py in <module>
     68     load_trace = pm.load_trace(directory= savefile)
     69     points = load_trace.point(-1)
---> 70     trace = pm.sample(draws=0, tune=20, discard_tuned_samples=False, chains=5, 
     71                         cores=5, return_inferencedata=False, trace = load_trace,
     72                         start=load_trace.point(-1), mp_ctx=ctx)

~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    557         _print_step_hierarchy(step)
    558         try:
--> 559             trace = _mp_sample(**sample_args, **parallel_args)
    560         except pickle.PickleError:
    561             _log.warning("Could not pickle model, sampling singlethreaded.")

~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1447     for idx in range(chain, chain + chains):
   1448         if trace is not None:
-> 1449             strace = _choose_backend(copy(trace), idx, model=model)
   1450         else:
   1451             strace = _choose_backend(None, idx, model=model)

~/.conda/envs/ML/lib/python3.8/copy.py in copy(x)
     90         reductor = getattr(x, "__reduce_ex__", None)
     91         if reductor is not None:
---> 92             rv = reductor(4)
     93         else:
     94             reductor = getattr(x, "__reduce__", None)

~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/backends/base.py in __getattr__(self, name)
    363                 )
    364             return self.get_values(name)
--> 365         if name in self.stat_names:
    366             return self.get_sampler_stats(name)
    367         raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/backends/base.py in stat_names(self)
    382         sampler_vars = [s.sampler_vars for s in self._straces.values()]
    383         if not all(svars == sampler_vars[0] for svars in sampler_vars):
--> 384             raise ValueError("Inividual chains contain different sampler stats")
    385         names = set()
    386         for trace in self._straces.values():

ValueError: Inividual chains contain different sampler stats

I run 5 chains on 5 cores. It looks like the sampler-stat types differ between the chains of the loaded trace. When I run:

print(load_trace._straces[2].sampler_vars)
print(load_trace._straces[0].sampler_vars)

the outputs are:

[{'depth': 'int64', 'step_size': 'float64', 'tune': 'bool', 'mean_tree_accept': 'float64', 'step_size_bar': 'float64', 'tree_size': 'float64', 'diverging': 'bool', 'energy_error': 'float64', 'energy': 'float64', 'max_energy_error': 'float64', 'model_logp': 'float64', 'process_time_diff': 'float64', 'perf_counter_diff': 'float64', 'perf_counter_start': 'float64'}]
[{'depth': <class 'numpy.int64'>, 'step_size': <class 'numpy.float64'>, 'tune': <class 'bool'>, 'mean_tree_accept': <class 'numpy.float64'>, 'step_size_bar': <class 'numpy.float64'>, 'tree_size': <class 'numpy.float64'>, 'diverging': <class 'bool'>, 'energy_error': <class 'numpy.float64'>, 'energy': <class 'numpy.float64'>, 'max_energy_error': <class 'numpy.float64'>, 'model_logp': <class 'numpy.float64'>, 'process_time_diff': <class 'numpy.float64'>, 'perf_counter_diff': <class 'numpy.float64'>, 'perf_counter_start': <class 'numpy.float64'>}]
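
So chain 2 stores the stat types as dtype name strings ('int64', 'float64', ...) while chain 0 stores them as numpy/Python classes, and MultiTrace.stat_names raises because it compares those dictionaries for equality. As a stop-gap I was thinking of normalizing everything to np.dtype before resuming. This is only a rough, untested sketch and it pokes at the private _straces / sampler_vars attributes of my pymc3 install, so I would rather understand the real cause:

    import numpy as np

    # Untested workaround sketch: coerce every chain's sampler-stat types to
    # np.dtype objects so the equality check in MultiTrace.stat_names
    # (and hence copy(trace) inside pm.sample) no longer fails.
    # Relies on the private _straces / sampler_vars attributes.
    for strace in load_trace._straces.values():
        strace.sampler_vars = [
            {name: np.dtype(dtype) for name, dtype in svars.items()}
            for svars in strace.sampler_vars
        ]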

Here is the part of the code where I load the trace and try to draw more samples:

    for _ in range(train_iter):
        with pm.Model() as pymodel:

            para_value = pm.Uniform('para_value', lower=0, upper=1, shape=(2,))
            like = pm.Potential('like', logp1(para_value))  # logp1 computes the log-likelihood

            # Load the previously saved trace and resume sampling from its last point.
            load_trace = pm.load_trace(directory=savefile)

            trace = pm.sample(ndraws, tune=nburn, discard_tuned_samples=False,
                              chains=chains_num, cores=cores_num,
                              return_inferencedata=False, trace=load_trace,
                              start=load_trace.point(-1), mp_ctx=ctx)
            new_iternum = trace.get_values(varname='para_value').shape[0]
            logger.info(f'new_iternum: {new_iternum}.')

        # Save the extended trace so the next iteration can resume from it.
        pm.save_trace(trace, directory=savefile, overwrite=True)
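
For completeness, the trace that gets loaded was produced by an initial run along these lines (a sketch reconstructed from the workflow above; the model and settings are the same ones used in the resume loop):

    with pm.Model() as pymodel:
        para_value = pm.Uniform('para_value', lower=0, upper=1, shape=(2,))
        like = pm.Potential('like', logp1(para_value))

        # First run: sample from scratch and save the trace to disk.
        trace = pm.sample(ndraws, tune=nburn, discard_tuned_samples=False,
                          chains=chains_num, cores=cores_num,
                          return_inferencedata=False, mp_ctx=ctx)
        pm.save_trace(trace, directory=savefile, overwrite=True)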

Does anyone know what is going on here, or what the right way is to resume the sampler?