Using callbacks in PyMC3 3.11.2 to test for divergences

I am trying to do something straightforward: I'd like a callback that (a) checks the number of divergences across two chains, and (b) raises a KeyboardInterrupt if there are more than 10 across the two chains.
I am using NUTS.


def divergence_callback(trace, draw):
    print(trace)
    if trace.sample_stats['diverging'].sum().item() > 10:
        raise KeyboardInterrupt()

The printed trace and the error are:

<pymc3.backends.ndarray.NDArray object at 0x7f0bc8599b20>
'NDArray' object has no attribute 'sample_stats'

I thought this was because I have two chains, but I get the same error with a single chain. I guess the sample_stats attribute only gets added after the chains have finished running. How can I go about this?

I've seen https://www.mbrouns.com/posts/2020-01-26-pymc3-sampling-callback/ but none of the callbacks there work with 3.11.2.

Specifically for this idea:

import pymc3 as pm

class MyCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every
        self.max_rhat = max_rhat
        self.traces = {}

    def __call__(self, trace, draw):
        if draw.tuning:
            return

        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt

The error is:

data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
could not broadcast input array from shape (1942,36) into shape (2000,36)
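The two shapes in this message show that the per-chain arrays being combined had different lengths when the callback fired (2000 rows for one chain, 1942 for another), presumably because the chains progress at different speeds, so they cannot be packed into a single (chain, draw, ...) array. The NumPy sketch below is not PyMC3 code: it reproduces the shape problem with dummy arrays and shows one possible workaround, truncating every chain to the shortest length before combining. np.stack is used instead of np.array only because it fails cleanly on ragged input; stack_chains is a hypothetical helper, and whether truncated traces can be fed back into MultiTrace or ArviZ this way is an assumption the thread does not confirm.

```python
import numpy as np

# Dummy per-chain arrays with the shapes from the error message:
# one chain has 2000 rows, the other only 1942.
chain_a = np.zeros((2000, 36))
chain_b = np.ones((1942, 36))

# Stacking ragged chains into one (chain, draw, ...) array fails.
try:
    np.stack([chain_a, chain_b])
except ValueError as err:
    print('cannot stack:', err)


def stack_chains(chains):
    """Hypothetical helper: stack per-chain arrays into shape
    (n_chains, n_draws, ...), first truncating every chain to the
    length of the shortest one."""
    n = min(len(c) for c in chains)
    return np.stack([c[:n] for c in chains])


print(stack_chains([chain_a, chain_b]).shape)  # (2, 1942, 36)
```

Truncating throws away the extra draws of the faster chain, which is acceptable for a convergence check but would bias nothing only if the discarded draws are simply not yet compared.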

The per-draw diagnostics are in draw.stats:

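import pymc3 as pm

class DivergentEarlyStopping:
    def __init__(self):
        self.count = 0  # divergences seen so far, across all chains

    def __call__(self, trace, draw):
        # ignore tuning draws
        if draw.tuning:
            return

        # draw.stats[0] holds the NUTS stats for this single draw
        self.count += int(draw.stats[0]['diverging'])
        if self.count > 10:
            raise KeyboardInterrupt()

callback = DivergentEarlyStopping()

# Neal's funnel, a model that tends to produce divergences
with pm.Model():
    y = pm.Normal('y', mu=0., sigma=3.)
    x = pm.Normal('x', mu=0., sigma=pm.math.exp(y/2), shape=9)
    trace = pm.sample(1000, callback=callback)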

This should work if you are sampling only with NUTS; if you are using a CompoundStep, you will need some further modification.
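One possible shape for that modification, as a sketch only: assuming that with a CompoundStep draw.stats still holds one dict per step method, and that only some methods (such as NUTS) report a 'diverging' key, the callback can sum over every entry instead of reading index 0. CompoundDivergentEarlyStopping and FakeDraw are hypothetical names; FakeDraw merely mimics the three fields the callback reads so the logic runs without a sampler, and the {'accept': ...} dicts are illustrative placeholders for a step method without divergence stats.

```python
from collections import namedtuple


class CompoundDivergentEarlyStopping:
    """Hypothetical variant of the callback above for CompoundStep sampling.

    Counts post-tuning divergences reported by any step method and raises
    KeyboardInterrupt once the running total exceeds `max_divergences`.
    """

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.count = 0

    def __call__(self, trace, draw):
        # Tuning draws are ignored, as in the NUTS-only version.
        if draw.tuning:
            return
        # Assumption: draw.stats holds one dict per step method; methods that
        # do not report 'diverging' add nothing to the count.
        self.count += sum(int(s.get('diverging', False)) for s in draw.stats)
        if self.count > self.max_divergences:
            raise KeyboardInterrupt()


# Hypothetical stand-in for PyMC3's draw object, carrying only the fields read above.
FakeDraw = namedtuple('FakeDraw', ['chain', 'tuning', 'stats'])

callback = CompoundDivergentEarlyStopping(max_divergences=2)
draws = [
    FakeDraw(chain=0, tuning=True, stats=[{'diverging': True}, {'accept': 0.3}]),    # tuning: skipped
    FakeDraw(chain=0, tuning=False, stats=[{'diverging': True}, {'accept': 0.5}]),   # count -> 1
    FakeDraw(chain=1, tuning=False, stats=[{'diverging': True}, {'accept': 0.2}]),   # count -> 2
    FakeDraw(chain=0, tuning=False, stats=[{'diverging': False}, {'accept': 0.9}]),  # count stays 2
    FakeDraw(chain=1, tuning=False, stats=[{'diverging': True}, {'accept': 0.4}]),   # count -> 3, stop
]
try:
    for d in draws:
        callback(trace=None, draw=d)
except KeyboardInterrupt:
    print('stopped after', callback.count, 'divergences')  # stopped after 3 divergences
```

With PyMC3 installed, an instance would be passed as callback=... to pm.sample in the same way as the NUTS-only class. With NUTS alone, draw.stats has a single entry, so this reduces to the original count.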


Thanks. So if I have two chains, should I iterate through draw.stats[0] and draw.stats[1]? Or should I implement it as draw.stats[0]['diverging'].sum().item()?

No need. draw.stats is a list with one entry per step method (sampler), and each draw comes from a single chain. When NUTS is your only sampler, it is a one-element list.
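To make this concrete for two chains: each call receives one draw from one chain, so draw.stats[0]['diverging'] is a single flag for that draw and needs no .sum(). Assuming, as the reply implies, that one callback object receives the draws of all chains, a plain counter already totals divergences across chains. The sketch below additionally keeps a per-chain breakdown keyed by draw.chain. PerChainDivergenceCounter and FakeDraw are hypothetical names; FakeDraw only mimics the fields the callback reads so the logic can be exercised without running a sampler.

```python
from collections import defaultdict, namedtuple


class PerChainDivergenceCounter:
    """Hypothetical callback: count post-tuning divergences per chain and
    raise KeyboardInterrupt once the total across chains exceeds `max_total`."""

    def __init__(self, max_total=10):
        self.max_total = max_total
        self.per_chain = defaultdict(int)

    def __call__(self, trace, draw):
        if draw.tuning:
            return
        # Each draw belongs to exactly one chain (draw.chain); with NUTS only,
        # draw.stats is a one-element list holding this draw's stats.
        self.per_chain[draw.chain] += int(draw.stats[0]['diverging'])
        if sum(self.per_chain.values()) > self.max_total:
            raise KeyboardInterrupt()


# Hypothetical stand-in for PyMC3's draw object.
FakeDraw = namedtuple('FakeDraw', ['chain', 'tuning', 'stats'])

counter = PerChainDivergenceCounter(max_total=10)
# Alternate draws between chains 0 and 1; the first four are divergent.
for i in range(6):
    counter(trace=None, draw=FakeDraw(chain=i % 2, tuning=False,
                                      stats=[{'diverging': i < 4}]))
print(dict(counter.per_chain))  # {0: 2, 1: 2}
```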
In fact, this is more or less how the number of divergences is displayed during sampling: pymc3/sampling.py at commit 819f045ad36d1d8b18651528384972dc2bea8213 in pymc-devs/pymc3 on GitHub.

Thanks, let me check and report back.