I am trying to do something straightforward: I’d like to have a callback that a) counts the number of divergences across two chains, and b) raises a `KeyboardInterrupt` if there are more than 10 in total across both chains.
I am using NUTS.
```python
def divergence_callback(trace, draw):
    print(trace)
    if trace.sample_stats['diverging'].sum().item() > 10:
        raise KeyboardInterrupt()
```
The `print(trace)` output and the error are:
```
<pymc3.backends.ndarray.NDArray object at 0x7f0bc8599b20>
'NDArray' object has no attribute 'sample_stats'
```
I thought this was because I have two chains, but I get the same error with a single chain. I guess the `sample_stats` attribute only exists after the chains have finished running. How can I go about this?
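A minimal sketch of one workaround, assuming the PyMC3 3.11 callback protocol: `draw` is a `Draw` namedtuple with `chain`, `tuning` and `stats` fields, where `stats` is a list holding one dict of sampler statistics per step method, and NUTS reports a `"diverging"` flag there. Instead of reading `sample_stats` (which the per-chain `NDArray` trace doesn't have), it counts divergences draw by draw and keeps a running total per chain on a callable object, so the limit applies across both chains. `DivergenceCallback` and `max_divergences` are my own hypothetical names.

```python
from collections import defaultdict


class DivergenceCallback:
    """Stop sampling once post-tuning divergences, summed over all chains, exceed a limit.

    Assumes PyMC3 3.11's ``callback(trace, draw)`` protocol: ``draw.chain`` is the
    chain id, ``draw.tuning`` flags warm-up draws, and ``draw.stats`` is a list with
    one dict of sampler stats per step method (NUTS reports ``"diverging"``).
    """

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        # Running count of divergences, keyed by chain id.
        self.counts = defaultdict(int)

    def __call__(self, trace, draw):
        # Skip warm-up draws, which are discarded from the final trace by default.
        if draw.tuning:
            return
        # Count the flag from every step method that reports it.
        n_div = sum(bool(s.get("diverging", False)) for s in draw.stats)
        self.counts[draw.chain] += n_div
        # Total across all chains seen so far.
        if sum(self.counts.values()) > self.max_divergences:
            raise KeyboardInterrupt
```

It would be passed as `pm.sample(..., callback=DivergenceCallback(max_divergences=10))`. Skipping tuning draws should make the count match what `sample_stats['diverging']` reports afterwards with the default `discard_tuned_samples=True`. As far as I can tell the callback runs in the main process even with `cores > 1`, which is what lets a single object see both chains.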
I’ve seen https://www.mbrouns.com/posts/2020-01-26-pymc3-sampling-callback/ but none of the callbacks there work with PyMC3 3.11.2.
Specifically, for this one:
```python
import pymc3 as pm


class MyCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every
        self.max_rhat = max_rhat
        self.traces = {}

    def __call__(self, trace, draw):
        if draw.tuning:
            return
        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt
```
The error is:
```
data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
could not broadcast input array from shape (1942,36) into shape (2000,36)
```
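The shape mismatch looks like the two chains holding different numbers of draws (1942 vs 2000) when the check fires, so building a `MultiTrace` mid-sampling fails. A minimal sketch that sidesteps `MultiTrace` entirely, under stated assumptions: the callback collects post-tuning values of one scalar variable from `draw.point` (assumed to be a dict of variable name to value), truncates every chain to the shortest one, and computes a hand-rolled classic Gelman-Rubin R-hat. This is a swapped-in technique, not what `pm.stats.rhat` does (which, as far as I know, delegates to ArviZ's rank-normalized split R-hat). `rhat_common_length`, `RhatCallback` and `var_name` are hypothetical names.

```python
from collections import defaultdict

import numpy as np


def rhat_common_length(chains):
    """Classic Gelman-Rubin R-hat for one scalar variable.

    `chains` is a list of 1-D sequences, one per chain, possibly of different
    lengths. Each is truncated to the shortest so they stack into an (m, n) array.
    """
    n = min(len(c) for c in chains)
    if len(chains) < 2 or n < 2:
        raise ValueError("need at least two chains with at least two draws each")
    x = np.stack([np.asarray(c[:n], dtype=float) for c in chains])  # shape (m, n)
    # Between-chain variance B and mean within-chain variance W.
    b = n * x.mean(axis=1).var(ddof=1)
    w = x.var(axis=1, ddof=1).mean()
    # Pooled estimate of the posterior variance.
    var_plus = (n - 1) / n * w + b / n
    return float(np.sqrt(var_plus / w))


class RhatCallback:
    """Hypothetical callback: stop once R-hat of one scalar variable is below `max_rhat`.

    Assumes PyMC3 3.11's `callback(trace, draw)` protocol, with `draw.chain`,
    `draw.tuning` and `draw.point` (a dict of variable name -> value).
    """

    def __init__(self, var_name, every=1000, max_rhat=1.05):
        self.var_name = var_name  # hypothetical: name of a scalar variable in draw.point
        self.every = every
        self.max_rhat = max_rhat
        self.draws = defaultdict(list)  # chain id -> post-tuning values

    def __call__(self, trace, draw):
        if draw.tuning:
            return
        values = self.draws[draw.chain]
        values.append(float(draw.point[self.var_name]))
        # Check every `every` draws of this chain, once at least two chains reported.
        if len(values) % self.every != 0 or len(self.draws) < 2:
            return
        if min(len(v) for v in self.draws.values()) < 2:
            return
        if rhat_common_length(list(self.draws.values())) < self.max_rhat:
            raise KeyboardInterrupt
```

Usage would be something like `pm.sample(..., callback=RhatCallback("mu"))`, where `"mu"` stands in for a real variable name. If I understand correctly, `draw.point` holds free variables in their transformed space, so a bounded variable would appear under its transformed name (e.g. `sigma_log__`). Truncating to the shortest chain discards a few of the newest draws, which seems acceptable for a stopping heuristic.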