Progress bar Callback / Hook

Here’s the documentation: pymc.sample — PyMC 5.10.0 documentation

And see the callback argument:

callback : function, default=None
A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. The draw.chain argument can be used to determine which of the active chains the sample is drawn from. Sampling can be interrupted by throwing a KeyboardInterrupt in the callback.
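
That last point means the callback can also serve as an early-stopping hook. Here's a minimal sketch of the idea (the MAX_DIVERGENCES threshold and the function name are made up for illustration; the "diverging" stats key is visible in the output further down):

MAX_DIVERGENCES = 10  # illustrative threshold, not a PyMC setting

divergence_counts = {}  # running per-chain tally

def stop_on_divergences(trace=None, draw=None, **kwargs):
    # draw.chain identifies which of the active chains produced this sample
    chain = draw.chain
    # draw.stats is a list of per-sampler stat dicts; NUTS reports a "diverging" flag
    if any(stats.get("diverging", False) for stats in draw.stats):
        divergence_counts[chain] = divergence_counts.get(chain, 0) + 1
    # per the docs, raising KeyboardInterrupt inside the callback stops sampling
    if divergence_counts.get(chain, 0) >= MAX_DIVERGENCES:
        raise KeyboardInterrupt(f"chain {chain} hit {MAX_DIVERGENCES} divergences")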

And here’s a small example that just prints what the callback receives. See, it gives you so much information, which is super cool!

import numpy as np
import pymc as pm

# keep the run tiny so the callback fires only a couple of times per chain
CHAINS = 2
DRAWS = 1
TUNE = 1

# toy linear-regression data (np.linspace defaults to 50 points)
x = np.linspace(0, 10)
y = 5 + x + np.random.normal(size=50)

with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Normal("b")
    mu = a + b * x
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)

# the callback receives the trace so far and the current Draw as keyword arguments
def callback(**kwargs):
    print(kwargs["trace"])
    print(kwargs["draw"])

with model:
    pm.sample(chains=CHAINS, tune=TUNE, draws=DRAWS, callback=callback)
Draw(chain=0, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.001911101000018789, 'process_time_diff': 0.0019091170000000005, 'perf_counter_start': 454.54123655, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 1.8746229032034632e+144.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2405.550024556053, 'tree_size': 1, 'max_energy_error': 1.8746229032034632e+144, 'model_logp': array(-2403.71759852), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=0, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.0004060370000047442, 'process_time_diff': 0.00040501100000000026, 'perf_counter_start': 454.543718379, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: inf.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)}, divergence_point_dest={'a': array(53.78360997), 'b': array(305.86961388), 'sigma_log__': array(459.63442126)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: inf.', exec_info=None, state=State(q=RaveledVars(data=array([ 0.29361544,  0.32976448, -0.14467011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.80052794, -0.28900898,  0.26328982]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.80052794, -0.28900898,  0.26328982]), q_grad=array([ 540.1229881 , 3104.54772793, 4674.8992078 ]), energy=2404.1144448671844, model_logp=array(-2403.71759852), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 53.78360997, 305.86961388, 459.63442126]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-108.66242886, -620.98878164,           inf]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-108.66242886, -620.98878164,           inf]), q_grad=array([ -53.78360997, -305.86961388,          -inf]), energy=inf, model_logp=array(-inf), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2404.1144448671844, 'tree_size': 1, 'max_energy_error': inf, 'model_logp': array(-2403.71759852), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=1, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.0016176770000129181, 'process_time_diff': 0.0016152749999999998, 'perf_counter_start': 454.542239224, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 8.692794022538529e+39.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 774.3912185967521, 'tree_size': 1, 'max_energy_error': 8.692794022538529e+39, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})
Draw(chain=1, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.00042231799994851826, 'process_time_diff': 0.00042210300000000006, 'perf_counter_start': 454.544337308, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)}, divergence_point_dest={'a': array(10.4403694), 'b': array(61.72045712), 'sigma_log__': array(128.92575614)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', exec_info=None, state=State(q=RaveledVars(data=array([-0.05543783, -0.30937351,  0.85966873]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.5260765 , -1.15454405,  1.09184011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.5260765 , -1.15454405,  1.09184011]), q_grad=array([ 104.31916481,  625.33526582, 1306.73082727]), energy=775.0309205623541, model_logp=array(-773.62999892), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 10.4403694 ,  61.72045712, 128.92575614]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-2.13468156e+001, -1.26154837e+002,  2.13509678e+111]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-2.13468156e+001, -1.26154837e+002,  2.13509678e+111]), q_grad=array([-1.04403694e+001, -6.17204571e+001, -9.62695476e+111]), energy=2.279319132558379e+222, model_logp=array(-4.81347738e+111), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 775.0309205623541, 'tree_size': 1, 'max_energy_error': 2.279319132558379e+222, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})
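
And since the thread is about progress bars: the same Draw fields can drive a simple progress report. A rough sketch reusing the model and constants from the example above (report_progress is just an illustrative name; draw_idx counts tuning and posterior draws together, as the output shows):

TOTAL = TUNE + DRAWS  # per-chain total, based on the settings above

def report_progress(trace=None, draw=None, **kwargs):
    phase = "tune" if draw.tuning else "draw"
    marker = " (last)" if draw.is_last else ""
    print(f"chain {draw.chain}: {phase} {draw.draw_idx + 1}/{TOTAL}{marker}")

with model:
    pm.sample(chains=CHAINS, tune=TUNE, draws=DRAWS, callback=report_progress)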