How can I terminate pm.sample when a convergence criterion is met (e.g. r_hat < 1.05)?

Hi all,

I am wondering whether the documentation page for Sample Callbacks is still valid for PyMC version 4. I have tried reproducing the first example, which simply terminates the sampler once the number of samples reaches 100 (even if the draws argument is set to a higher value). However, PyMC raises a TypeError because the trace argument passed to the callback is None. For convenience, I have included a self-contained minimum working example below.

Minimum working example

import numpy as np
import pymc as pm
print(f'PyMC v{pm.__version__}')

# Generate artificial data
true_mu = 0.
true_sigma = 1.
data = np.random.normal(loc=true_mu, scale=true_sigma, size=1000)

# Define custom callback (taken from pymc.io/projects/examples/en/latest/howto/sampling_callback.html)
def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()

with pm.Model() as model:
    mu = pm.Flat("mu")
    sigma = pm.Flat("sigma")
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
    idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)

Output:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
PyMC v4.0.0
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]

 0.01% [1/8000 00:00<00:11 Sampling 4 chains, 0 divergences]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [13], in <cell line: 31>()
     33 sigma = pm.Flat("sigma")
     34 obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
---> 35 idata = pm.sample(draws=1000, tune=1000, chains=4, callback=my_callback)

File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:607, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    605 _print_step_hierarchy(step)
    606 try:
--> 607     mtrace = _mp_sample(**sample_args, **parallel_args)
    608 except pickle.PickleError:
    609     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/miniconda3/lib/python3.9/site-packages/pymc/sampling.py:1547, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
   1544                     strace._add_warnings(draw.warnings)
   1546             if callback is not None:
-> 1547                 callback(trace=trace, draw=draw)
   1549 except ps.ParallelSamplingError as error:
   1550     strace = traces[error._chain - chain]

Input In [13], in my_callback(trace, draw)
      8 def my_callback(trace, draw):
----> 9     if len(trace) >= 100:
     10         raise KeyboardInterrupt()

TypeError: object of type 'NoneType' has no len()

Why is trace NoneType?

Additional context: My end goal is to write a callback function that terminates the sampler when the Gelman-Rubin statistic satisfies r_hat < 1.01, which is the subject of the second example on the Sample Callback page.

Hey,

I have encountered the same issue on my model (using PyMC v4.2.0) and could reproduce the error you got.

I do not have a solution to the problem, but what I’ve found is that the error only appears when using multiple chains. If you need the callback, a workaround (not a solution) may be to set chains=1 and run the sampling process multiple times.

I have run into the same issue; the example code doesn’t work.

The watermark at the bottom of the notebook shows it was last executed with PyMC 3.11, so it is not surprising that it doesn’t work with 4.x or 5.x versions. We have an issue open to update the notebook, but so far nobody has gotten to it: "sampling callback", Issue #112 in pymc-devs/pymc-examples on GitHub. I personally do not know how one should go about fixing it, but if someone is interested in fixing the notebook, we can check whether someone can help.
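Until the notebook is updated, here is a rough sketch of one possible workaround for the r_hat goal: a callback that ignores the trace argument entirely (it is None in the traceback above) and collects the draws itself from the draw argument. Everything here is an assumption rather than confirmed PyMC behaviour: the names RHatStopper and r_hat are made up for illustration, the callback relies on draw exposing chain, tuning and point attributes, on point being a dict of scalar values, and on the callback running in the main process so it can keep state across chains. The statistic is the classic (non-split) Gelman-Rubin formula, not the rank-normalized version ArviZ reports, so the numbers will not match az.rhat exactly.

```python
from collections import defaultdict
from statistics import fmean, variance


def r_hat(chains):
    """Classic (non-split) Gelman-Rubin statistic for equal-length chains of floats."""
    n = len(chains[0])
    chain_means = [fmean(c) for c in chains]
    within = fmean([variance(c) for c in chains])  # W: mean within-chain variance
    between = n * variance(chain_means)            # B: between-chain variance
    if within == 0.0:
        return float("inf")
    var_hat = (n - 1) / n * within + between / n   # pooled variance estimate
    return (var_hat / within) ** 0.5


class RHatStopper:
    """Hypothetical callback: stop sampling once every variable has r_hat < threshold.

    Assumes `draw` has `.chain`, `.tuning` and `.point` (dict of scalar values).
    """

    def __init__(self, n_chains, threshold=1.01, min_draws=100, check_every=50):
        self.n_chains = n_chains
        self.threshold = threshold
        self.min_draws = min_draws
        self.check_every = check_every
        self.n_seen = 0
        # samples[variable name][chain index] -> list of floats
        self.samples = defaultdict(lambda: defaultdict(list))

    def max_r_hat(self):
        """Largest r_hat over all variables, or None if there is not enough data yet."""
        values = []
        for per_chain in self.samples.values():
            if len(per_chain) < self.n_chains:
                return None
            n = min(len(c) for c in per_chain.values())
            if n < self.min_draws:
                return None
            values.append(r_hat([c[:n] for c in per_chain.values()]))
        return max(values) if values else None

    def __call__(self, trace, draw):
        # `trace` is deliberately unused, since it may be None.
        if draw.tuning:
            return
        for name, value in draw.point.items():
            # float() also copies the value out of any reused buffer.
            self.samples[name][draw.chain].append(float(value))
        self.n_seen += 1
        if self.n_seen % self.check_every:
            return
        worst = self.max_r_hat()
        if worst is not None and worst < self.threshold:
            # Same stopping mechanism as the example in the question.
            raise KeyboardInterrupt
```

If those assumptions hold, it would be passed as pm.sample(draws=1000, tune=1000, chains=4, callback=RHatStopper(n_chains=4)). Note that the keys of draw.point may be the transformed variable names rather than the names used in the model, and that stopping at the first moment r_hat dips below the threshold can end sampling too early, so min_draws should be set generously.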