Is there any way to use the callback functionality of pm.sample() to get the current status of the sampling progress? Ideally I’d log the sampling progress so I can update a status page.
Cheers!
Is there any way to use the callback functionality of pm.sample() to get the current status of the sampling progress? Ideally I’d log the sampling progress so I can update a status page.
Cheers!
Here’s the documentation: pymc.sample — PyMC 5.10.0 documentation
And see the `callback` argument:
callback : function, default=None
A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. The `draw.chain` argument can be used to determine which of the active chains the sample is drawn from. Sampling can be interrupted by throwing a `KeyboardInterrupt` in the callback.
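And here’s a small example. As you can see, the callback gives you a lot of information, which is super cool! For progress reporting, the draw’s `chain`, `draw_idx`, `tuning`, and `is_last` fields are the most useful.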
import numpy as np
import pymc as pm
# Sampler settings. Kept deliberately tiny so the callback output printed
# below stays short enough to read.
CHAINS = 2
DRAWS = 1
TUNE = 1

# Synthetic data for a simple linear regression: y = 5 + x + standard-normal noise.
x = np.linspace(0, 10)
# Derive the noise size from x instead of hard-coding 50 (np.linspace's
# default point count), so x and y keep matching shapes if the grid changes.
y = 5 + x + np.random.normal(size=x.shape)

with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Normal("b")
    mu = a + b * x
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)
def callback(**kwargs):
    """Print the trace and then the current draw passed in by ``pm.sample``.

    Only the ``trace`` and ``draw`` keyword arguments are read; any other
    keyword arguments are ignored. A missing key raises ``KeyError``.
    """
    for field_name in ("trace", "draw"):
        print(kwargs[field_name])
# Sample inside the model context; pm.sample calls `callback` for every
# draw of every chain.
with model:
    pm.sample(
        draws=DRAWS,
        tune=TUNE,
        chains=CHAINS,
        callback=callback,
    )
Draw(chain=0, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.001911101000018789, 'process_time_diff': 0.0019091170000000005, 'perf_counter_start': 454.54123655, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 1.8746229032034632e+144.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2405.550024556053, 'tree_size': 1, 'max_energy_error': 1.8746229032034632e+144, 'model_logp': array(-2403.71759852), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=0, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.0004060370000047442, 'process_time_diff': 0.00040501100000000026, 'perf_counter_start': 454.543718379, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: inf.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)}, divergence_point_dest={'a': array(53.78360997), 'b': array(305.86961388), 'sigma_log__': array(459.63442126)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: inf.', exec_info=None, state=State(q=RaveledVars(data=array([ 0.29361544, 0.32976448, -0.14467011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.80052794, -0.28900898, 0.26328982]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.80052794, -0.28900898, 0.26328982]), q_grad=array([ 540.1229881 , 3104.54772793, 4674.8992078 ]), energy=2404.1144448671844, model_logp=array(-2403.71759852), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 53.78360997, 305.86961388, 459.63442126]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-108.66242886, -620.98878164, inf]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-108.66242886, -620.98878164, inf]), q_grad=array([ -53.78360997, -305.86961388, -inf]), energy=inf, model_logp=array(-inf), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 2404.1144448671844, 'tree_size': 1, 'max_energy_error': inf, 'model_logp': array(-2403.71759852), 
'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(0.29361544), 'b': array(0.32976448), 'sigma_log__': array(-0.14467011)})
Draw(chain=1, is_last=False, draw_idx=0, tuning=True, stats=[{'tune': True, 'diverging': True, 'perf_counter_diff': 0.0016176770000129181, 'process_time_diff': 0.0016152749999999998, 'perf_counter_start': 454.542239224, 'warning': SamplerWarning(kind=<WarningType.TUNING_DIVERGENCE: 2>, message='Energy change in leapfrog step is too large: 8.692794022538529e+39.', level='debug', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None), 'depth': 1, 'mean_tree_accept': 0.0, 'energy_error': 0.0, 'energy': 774.3912185967521, 'tree_size': 1, 'max_energy_error': 8.692794022538529e+39, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})
Draw(chain=1, is_last=True, draw_idx=1, tuning=False, stats=[{'tune': False, 'diverging': True, 'perf_counter_diff': 0.00042231799994851826, 'process_time_diff': 0.00042210300000000006, 'perf_counter_start': 454.544337308, 'warning': SamplerWarning(kind=<WarningType.DIVERGENCE: 1>, message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', level='debug', step=1, exec_info=None, extra=None, divergence_point_source={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)}, divergence_point_dest={'a': array(10.4403694), 'b': array(61.72045712), 'sigma_log__': array(128.92575614)}, divergence_info=DivergenceInfo(message='Energy change in leapfrog step is too large: 2.279319132558379e+222.', exec_info=None, state=State(q=RaveledVars(data=array([-0.05543783, -0.30937351, 0.85966873]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-0.5260765 , -1.15454405, 1.09184011]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-0.5260765 , -1.15454405, 1.09184011]), q_grad=array([ 104.31916481, 625.33526582, 1306.73082727]), energy=775.0309205623541, model_logp=array(-773.62999892), index_in_trajectory=0), state_div=State(q=RaveledVars(data=array([ 10.4403694 , 61.72045712, 128.92575614]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), p=RaveledVars(data=array([-2.13468156e+001, -1.26154837e+002, 2.13509678e+111]), point_map_info=(('a', (), dtype('float64')), ('b', (), dtype('float64')), ('sigma_log__', (), dtype('float64')))), v=array([-2.13468156e+001, -1.26154837e+002, 2.13509678e+111]), q_grad=array([-1.04403694e+001, -6.17204571e+001, -9.62695476e+111]), energy=2.279319132558379e+222, model_logp=array(-4.81347738e+111), index_in_trajectory=-1))), 'depth': 1, 'mean_tree_accept': 0.0, 
'energy_error': 0.0, 'energy': 775.0309205623541, 'tree_size': 1, 'max_energy_error': 2.279319132558379e+222, 'model_logp': array(-773.62999892), 'index_in_trajectory': 0, 'reached_max_treedepth': False, 'step_size': 0.4435663891103336, 'step_size_bar': 0.4435663891103336, 'largest_eigval': nan, 'smallest_eigval': nan}], point={'a': array(-0.05543783), 'b': array(-0.30937351), 'sigma_log__': array(0.85966873)})