Hi!
I am trying to run a NUTS sampler. I want to save the model at some point and resume the sampler later to get more samples, so I use pm.save_trace
and pm.load_trace
. This works the first time I run the sampler and save/load the trace. However, when I then call pm.sample()
with trace=load_trace
, I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_63722/2393317151.py in <module>
68 load_trace = pm.load_trace(directory= savefile)
69 points = load_trace.point(-1)
---> 70 trace = pm.sample(draws=0, tune=20, discard_tuned_samples=False, chains=5,
71 cores=5, return_inferencedata=False, trace = load_trace,
72 start=load_trace.point(-1), mp_ctx=ctx)
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
557 _print_step_hierarchy(step)
558 try:
--> 559 trace = _mp_sample(**sample_args, **parallel_args)
560 except pickle.PickleError:
561 _log.warning("Could not pickle model, sampling singlethreaded.")
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1447 for idx in range(chain, chain + chains):
1448 if trace is not None:
-> 1449 strace = _choose_backend(copy(trace), idx, model=model)
1450 else:
1451 strace = _choose_backend(None, idx, model=model)
~/.conda/envs/ML/lib/python3.8/copy.py in copy(x)
90 reductor = getattr(x, "__reduce_ex__", None)
91 if reductor is not None:
---> 92 rv = reductor(4)
93 else:
94 reductor = getattr(x, "__reduce__", None)
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/backends/base.py in __getattr__(self, name)
363 )
364 return self.get_values(name)
--> 365 if name in self.stat_names:
366 return self.get_sampler_stats(name)
367 raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
~/.conda/envs/ML/lib/python3.8/site-packages/pymc3/backends/base.py in stat_names(self)
382 sampler_vars = [s.sampler_vars for s in self._straces.values()]
383 if not all(svars == sampler_vars[0] for svars in sampler_vars):
--> 384 raise ValueError("Inividual chains contain different sampler stats")
385 names = set()
386 for trace in self._straces.values():
ValueError: Inividual chains contain different sampler stats
I run 5 chains on 5 cores. It seems that the types of the sampler-statistic variables differ between the chains in the saved trace. When I ran:
print(load_trace._straces[2].sampler_vars)
print(load_trace._straces[0].sampler_vars)
the outputs are:
[{'depth': 'int64', 'step_size': 'float64', 'tune': 'bool', 'mean_tree_accept': 'float64', 'step_size_bar': 'float64', 'tree_size': 'float64', 'diverging': 'bool', 'energy_error': 'float64', 'energy': 'float64', 'max_energy_error': 'float64', 'model_logp': 'float64', 'process_time_diff': 'float64', 'perf_counter_diff': 'float64', 'perf_counter_start': 'float64'}]
[{'depth': <class 'numpy.int64'>, 'step_size': <class 'numpy.float64'>, 'tune': <class 'bool'>, 'mean_tree_accept': <class 'numpy.float64'>, 'step_size_bar': <class 'numpy.float64'>, 'tree_size': <class 'numpy.float64'>, 'diverging': <class 'bool'>, 'energy_error': <class 'numpy.float64'>, 'energy': <class 'numpy.float64'>, 'max_energy_error': <class 'numpy.float64'>, 'model_logp': <class 'numpy.float64'>, 'process_time_diff': <class 'numpy.float64'>, 'perf_counter_diff': <class 'numpy.float64'>, 'perf_counter_start': <class 'numpy.float64'>}]
Here is the part of the code where I load the trace and try to draw some more samples:
# Repeatedly resume the saved NUTS run and extend it with more samples.
# NOTE(review): indentation restored from the flattened paste — the model
# context manager must enclose load_trace/sample/save_trace for PyMC3 to
# associate them with `pymodel`.
for _ in range(train_iter):
    with pm.Model() as pymodel:
        # Uniform prior over the 2-dimensional parameter vector.
        para_value = pm.Uniform('para_value', lower=0, upper=1, shape=(2,))
        # logp1 is some function to calculate the log_likelihood.
        like = pm.Potential('like', logp1(para_value))
        # Resume from the trace saved by the previous iteration/run.
        load_trace = pm.load_trace(directory=savefile)
        trace = pm.sample(ndraws, tune=nburn, discard_tuned_samples=False,
                          chains=chains_num, cores=cores_num,
                          return_inferencedata=False, trace=load_trace,
                          start=load_trace.point(-1), mp_ctx=ctx)
        # Total number of draws accumulated so far (old + new).
        new_iternum = trace.get_values(varname='para_value').shape[0]
        logger.info(f'new_iternum: {new_iternum}.')
        # Overwrite the save file so the next iteration resumes from here.
        pm.save_trace(trace, directory=savefile, overwrite=True)
So does anyone know what happened here?