How can I terminate pm.sample when a convergence criterion is met (e.g. r_hat < 1.05)?

Hi all,

I am wondering whether the documentation page for Sample Callbacks is still valid in PyMC version 4. I tried reproducing the first example, which simply terminates the sampler once the number of samples reaches 100 (even if the draws argument is set to a higher value). However, PyMC raises a TypeError because trace is None inside the callback. For convenience, I have included a self-contained minimal working example below.

Minimal working example

import numpy as np
import pymc as pm
print(f'PyMC v{pm.__version__}')

# Generate artificial data
true_mu = 0.
true_sigma = 1.
data = np.random.normal(loc=true_mu, scale=true_sigma, size=1000)

# Define custom callback (taken from pymc.io/projects/examples/en/latest/howto/sampling_callback.html)
def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()

with pm.Model() as model:
    mu = pm.Flat("mu")
    sigma = pm.Flat("sigma")
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
    idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)

Output:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
PyMC v4.0.0
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]

 0.01% [1/8000 00:00<00:11 Sampling 4 chains, 0 divergences]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [13], in <cell line: 31>()
     33 sigma = pm.Flat("sigma")
     34 obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
---> 35 idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)

File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:607, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    605 _print_step_hierarchy(step)
    606 try:
--> 607     mtrace = _mp_sample(**sample_args, **parallel_args)
    608 except pickle.PickleError:
    609     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:1547, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
   1544                     strace._add_warnings(draw.warnings)
   1546             if callback is not None:
-> 1547                 callback(trace=trace, draw=draw)
   1549 except ps.ParallelSamplingError as error:
   1550     strace = traces[error._chain - chain]

Input In [13], in my_callback(trace, draw)
      8 def my_callback(trace, draw):
----> 9     if len(trace) >= 100:
     10         raise KeyboardInterrupt()

TypeError: object of type 'NoneType' has no len()

Why is trace NoneType?

Additional context: My end goal is to write a callback function that terminates the sampler when the Gelman-Rubin statistic satisfies r_hat < 1.01, which is the subject of the second example on the Sample Callbacks page.

Hey,

I have encountered the same issue with my model (using PyMC v4.2.0) and could reproduce the error you got.

I do not have a solution to the problem, but I found that the error only appears when sampling multiple chains. If you need the callback, a workaround (not a fix) may be to set chains=1 and run the sampling process several times.
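
Another possible workaround, in case it helps: since only the trace argument arrives as None, the callback can ignore it and keep its own per-chain record built from the draw argument. Below is a minimal sketch of that idea, not a tested fix. It assumes that draw exposes chain, tuning and point fields (as the Draw namedtuple in PyMC v4 appears to), that the monitored variable is a scalar, and that raising KeyboardInterrupt stops sampling as in the documentation example. The names RHatStopper and gelman_rubin, and all their parameters, are made up for illustration. The statistic is the classic (non-split) Gelman-Rubin formula, which is not the same as the rank-normalized r_hat that ArviZ reports.

```python
from collections import defaultdict

import numpy as np


def gelman_rubin(chains):
    """Classic (non-split) Gelman-Rubin statistic.

    `chains` has shape (n_chains, n_draws) with n_chains >= 2.
    """
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    # W: average of the within-chain variances
    within = chains.var(axis=1, ddof=1).mean()
    # B: n_draws times the variance of the chain means
    between = n_draws * chains.mean(axis=1).var(ddof=1)
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(pooled / within))


class RHatStopper:
    """Hypothetical callback that never touches `trace`."""

    def __init__(self, var_name, n_chains, threshold=1.01,
                 min_draws=100, check_every=50):
        self.var_name = var_name
        self.n_chains = n_chains
        self.threshold = threshold
        self.min_draws = min_draws
        self.check_every = check_every
        self.samples = defaultdict(list)  # chain index -> list of draws
        self._last_checked = 0

    def __call__(self, trace, draw):
        # Assumption: `draw` has `tuning`, `chain` and `point` attributes.
        if draw.tuning:
            return  # skip warm-up draws
        self.samples[draw.chain].append(float(draw.point[self.var_name]))
        if len(self.samples) < self.n_chains:
            return  # not every chain has produced a post-tuning draw yet
        n = min(len(values) for values in self.samples.values())
        if n < self.min_draws or n == self._last_checked or n % self.check_every:
            return
        self._last_checked = n
        # Truncate all chains to the shortest one before comparing them.
        chains = np.array([values[:n] for values in self.samples.values()])
        if gelman_rubin(chains) < self.threshold:
            raise KeyboardInterrupt
```

With the model from the question, this would be passed as callback=RHatStopper("mu", n_chains=4) in the pm.sample call. Note that draw.point may hold transformed variable names for constrained priors, so the key to monitor can differ from the name used in the model.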