Progress bar Callback / Hook

Is there any way to use the callback functionality of pm.sample() to get the current status of the sampling progress? Ideally, I’d like to log the sampling progress so that I can update a status page.

Cheers!

Here’s the documentation: pymc.sample — PyMC 5.10.0 documentation

And see the callback argument

callback : function, default=None
A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. The draw.chain argument can be used to determine which of the active chains the sample is drawn from. Sampling can be interrupted by throwing a KeyboardInterrupt in the callback.

And here’s a small example. As you can see, the callback gives you a lot of information, which is super cool!

import numpy as np
import pymc as pm

# Keep the run deliberately tiny so the callback output stays readable.
CHAINS = 2
DRAWS = 1
TUNE = 1

# Synthetic linear data: y = 5 + x + standard-normal noise.
x = np.linspace(0, 10)
# Size the noise from x itself rather than hard-coding 50 (the implicit
# default number of points of np.linspace), so the two arrays cannot get
# out of sync if the grid is ever changed.
y = 5 + x + np.random.normal(size=x.shape)

# Linear regression of y on x with intercept `a`, slope `b` and noise
# scale `sigma`.
with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Normal("b")
    mu = a + b * x
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)

def callback(**kwargs):
    """Print the ``trace`` and ``draw`` keyword arguments, in that order.

    Intended to be passed as ``callback=`` to ``pm.sample``. Raises
    ``KeyError`` if either keyword is missing.
    """
    for key in ("trace", "draw"):
        print(kwargs[key])

# Re-enter the model context and sample, handing `callback` to pm.sample
# so it is invoked during sampling.
with model:
    pm.sample(
        chains=CHAINS,
        tune=TUNE,
        draws=DRAWS,
        callback=callback,
    )
Draw(chain=0, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.001911101000018789, 'process_time_diff': 0.0019091170000000005, 'perf_counter_start': 454.54123655, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 1.8746229032034632e+144.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2405.550024556053, 'tree_size': 1, 'max_energy_error': 1.8746229032034632e+144, 'model_logp': array(-2403.71759852), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=0, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.0004060370000047442, 'process_time_diff': 0.00040501100000000026, 'perf_counter_start': 454.543718379, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: inf.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)}, divergence_point_dest={'a': array(53.78360997), 'b': array(305.86961388), 'sigma_log__': array(459.63442126)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: inf.', exec_info=None, state=State(q=RaveledVars(data=array([ 0.29361544,  0.32976448, -0.14467011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.80052794, -0.28900898,  0.26328982]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.80052794, -0.28900898,  0.26328982]), q_grad=array([ 540.1229881 , 3104.54772793, 4674.8992078 ]), energy=2404.1144448671844, model_logp=array(-2403.71759852), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 53.78360997, 305.86961388, 459.63442126]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-108.66242886, -620.98878164,           inf]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-108.66242886, -620.98878164,           inf]), q_grad=array([ -53.78360997, -305.86961388,          -inf]), energy=inf, model_logp=array(-inf), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2404.1144448671844, 'tree_size': 1, 'max_energy_error': inf, 
'model_logp': array(-2403.71759852), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=1, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.0016176770000129181, 'process_time_diff': 0.0016152749999999998, 'perf_counter_start': 454.542239224, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 8.692794022538529e+39.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 774.3912185967521, 'tree_size': 1, 'max_energy_error': 8.692794022538529e+39, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})
Draw(chain=1, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.00042231799994851826, 'process_time_diff': 0.00042210300000000006, 'perf_counter_start': 454.544337308, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)}, divergence_point_dest={'a': array(10.4403694), 'b': array(61.72045712), 'sigma_log__': array(128.92575614)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', exec_info=None, state=State(q=RaveledVars(data=array([-0.05543783, -0.30937351,  0.85966873]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.5260765 , -1.15454405,  1.09184011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.5260765 , -1.15454405,  1.09184011]), q_grad=array([ 104.31916481,  625.33526582, 1306.73082727]), energy=775.0309205623541, model_logp=array(-773.62999892), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 10.4403694 ,  61.72045712, 128.92575614]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-2.13468156e+001, -1.26154837e+002,  2.13509678e+111]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-2.13468156e+001, -1.26154837e+002,  2.13509678e+111]), q_grad=array([-1.04403694e+001, -6.17204571e+001, -9.62695476e+111]), energy=2.279319132558379e+222, model_logp=array(-4.81347738e+111), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 
'energy_error': 0.0, 'energy': 775.0309205623541, 'tree_size': 1, 'max_energy_error': 2.279319132558379e+222, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})
2 Likes