Hi all,
I am wondering whether the documentation page for Sample Callbacks is still valid for PyMC version 4. I tried reproducing the first example, which simply terminates the sampler once the number of samples exceeds 100 (even if the draws argument is set to a higher value). However, PyMC raises a TypeError because trace is None for some reason. For convenience, I have included a self-contained minimum working example below.
Minimum working example
import numpy as np
import pymc as pm
print(f'PyMC v{pm.__version__}')
# Generate artificial data
true_mu = 0.
true_sigma = 1.
data = np.random.normal(loc=true_mu, scale=true_sigma, size=1000)
# Define custom callback (taken from pymc.io/projects/examples/en/latest/howto/sampling_callback.html)
def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()

with pm.Model() as model:
    mu = pm.Flat("mu")
    sigma = pm.Flat("sigma")
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
    idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)
Output:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
PyMC v4.0.0
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
0.01% [1/8000 00:00<00:11 Sampling 4 chains, 0 divergences]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [13], in <cell line: 31>()
33 sigma = pm.Flat("sigma")
34 obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
---> 35 idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)
File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:607, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
605 _print_step_hierarchy(step)
606 try:
--> 607 mtrace = _mp_sample(**sample_args, **parallel_args)
608 except pickle.PickleError:
609 _log.warning("Could not pickle model, sampling singlethreaded.")
File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:1547, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
1544 strace._add_warnings(draw.warnings)
1546 if callback is not None:
-> 1547 callback(trace=trace, draw=draw)
1549 except ps.ParallelSamplingError as error:
1550 strace = traces[error._chain - chain]
Input In [13], in my_callback(trace, draw)
8 def my_callback(trace, draw):
----> 9 if len(trace) >= 100:
10 raise KeyboardInterrupt()
TypeError: object of type 'NoneType' has no len()
Why is trace NoneType?
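From the traceback, it looks as if the callback is handed the trace argument of _mp_sample (which I never set) rather than the per-chain strace, but that is only my reading. As a possible stopgap, a callback could avoid touching trace entirely and count draws itself. The following is only a minimal sketch: the names StopAfterDraws and FakeDraw are hypothetical, and it assumes the draw object exposes chain and tuning attributes, which I have not confirmed for v4.0.0.

```python
from collections import Counter, namedtuple


class StopAfterDraws:
    """Callback that counts draws itself instead of calling len(trace)."""

    def __init__(self, max_draws=100):
        self.max_draws = max_draws
        self.counts = Counter()  # post-tuning draws seen so far, per chain

    def __call__(self, trace, draw):
        # trace may be None (as in the traceback above), so it is never used.
        # Assumption: draw.tuning flags warm-up draws; default to False if absent.
        if getattr(draw, "tuning", False):
            return
        # Assumption: draw.chain is the chain index; default to 0 if absent.
        chain = getattr(draw, "chain", 0)
        self.counts[chain] += 1
        if self.counts[chain] >= self.max_draws:
            raise KeyboardInterrupt()


# Stand-in for the sampler's draw object, used only for this demonstration.
FakeDraw = namedtuple("FakeDraw", ["chain", "tuning"])

callback = StopAfterDraws(max_draws=3)
stopped = False
try:
    for _ in range(10):
        callback(trace=None, draw=FakeDraw(chain=0, tuning=False))
except KeyboardInterrupt:
    stopped = True
```

If those assumptions hold, passing callback=StopAfterDraws(max_draws=100) to pm.sample should stop sampling once any single chain has produced 100 post-tuning draws, which is not quite the same as the len(trace) check in the documentation example.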
Additional context: My end goal is to write a callback function that terminates the sampler when the Gelman-Rubin statistic satisfies r_hat < 1.01, which is the subject of the second example on the Sample Callbacks page.
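To make the stopping criterion concrete, here is a rough sketch of the check I have in mind, written with plain NumPy. It uses the textbook (non-split) Gelman-Rubin formula, so its values will generally differ from what a library such as ArviZ reports, since those use more refined variants. The function names basic_rhat and converged are hypothetical, and the arrays are synthetic stand-ins for real chains.

```python
import numpy as np


def basic_rhat(chains):
    """Textbook Gelman-Rubin statistic for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    # W: average of the within-chain sample variances.
    within = chains.var(axis=1, ddof=1).mean()
    # B: n_draws times the sample variance of the chain means.
    between = n_draws * chains.mean(axis=1).var(ddof=1)
    # Pooled estimate of the posterior variance.
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(pooled / within))


def converged(chains, threshold=1.01):
    """True when the basic statistic falls below the threshold."""
    return basic_rhat(chains) < threshold


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))      # four chains drawn from the same distribution
stuck = mixed + np.arange(4)[:, None]   # same chains, each shifted to a different mean
```

A callback built on this would collect the draws of one parameter per chain and raise KeyboardInterrupt once converged(...) returns True; for the synthetic arrays above, mixed passes the check and stuck does not.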