Convergence quality API


Is there a way to get the convergence-check information programmatically?

I mean having this information (number of divergent samples, R-hat, etc.) stored in variables so that I can use it in scripts.

Thanks for your help.

Yup, you can extract the R-hat (\hat{R}) values using pm.summary, while the divergences can be read directly from the trace. Here's an example:

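import pymc3 as pm

with pm.Model() as model:
    x = pm.Normal('x', mu=0, sigma=1)
    trace = pm.sample()
    # Number of draws flagged as divergent, summed over all chains
    n_diverging = trace['diverging'].sum()
    # R-hat for each variable, taken from the summary table
    rhat = pm.summary(trace)['r_hat']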
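If the goal is to gate a script on these numbers, a small helper can turn them into a pass/fail decision. This is a minimal sketch, not part of PyMC3: the function name `convergence_ok` is hypothetical, and the 1.01 R-hat cutoff is an assumed default (a commonly cited threshold) that you may want to change.

```python
# Hypothetical helper (not part of PyMC3): decide whether a run looks converged
# from the divergence flags and R-hat values extracted from the trace/summary.
def convergence_ok(diverging, rhat, rhat_threshold=1.01):
    """Return (ok, n_diverging, bad_rhat).

    diverging: iterable of per-draw booleans, e.g. trace['diverging']
    rhat: mapping of variable name -> R-hat,
          e.g. pm.summary(trace)['r_hat'].to_dict()
    rhat_threshold: assumed cutoff; R-hat values above it are flagged
    """
    # Count the draws flagged as divergent
    n_diverging = sum(1 for d in diverging if d)
    # Collect the variables whose R-hat exceeds the threshold
    bad_rhat = {name: value for name, value in rhat.items() if value > rhat_threshold}
    ok = n_diverging == 0 and not bad_rhat
    return ok, n_diverging, bad_rhat


if __name__ == "__main__":
    # Toy inputs standing in for values taken from a real trace
    ok, n_div, bad = convergence_ok([False, True, False], {"x": 1.0, "y": 1.05})
    print(ok, n_div, bad)  # False 1 {'y': 1.05}
```

In a script, `ok, n_div, bad = convergence_ok(trace['diverging'], pm.summary(trace)['r_hat'].to_dict())` could then drive a warning, a longer re-run, or a non-zero exit code.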

To see which other statistics the sampler recorded, check the names listed in trace.stat_names.
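To get a quick overview of all of those statistics at once, one option is to loop over `trace.stat_names` and average each one. This is a sketch assuming a PyMC3 `MultiTrace` driven by a single step method such as NUTS, so that every statistic is one scalar per draw; the helper name `summarize_sampler_stats` is hypothetical.

```python
# Sketch: compute the mean of every sampler statistic a PyMC3 trace records.
# Assumes a PyMC3 MultiTrace, which exposes `stat_names` (a set of names) and
# `get_sampler_stats(name)`; the helper name itself is hypothetical.
def summarize_sampler_stats(trace):
    """Return a dict mapping each sampler statistic name to its mean over all draws."""
    summary = {}
    for name in sorted(trace.stat_names):
        # Values are combined across chains by default
        values = trace.get_sampler_stats(name)
        # Booleans (e.g. 'diverging') become 0.0/1.0, so their mean is a fraction
        flat = [float(v) for v in values]
        summary[name] = sum(flat) / len(flat) if flat else float("nan")
    return summary
```

Because `diverging` is a boolean flag, its entry in the result is the fraction of divergent draws rather than a count.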

