Carefully reviewing every diagnostic plot works for prototype models, but what if we are deploying hundreds of models into production? Has anyone made an effort to automate diagnostics in an approximate, incomplete way? I say incomplete because I have a strong prior that no automation will be able to replace careful manual review by an expert.
Here is a humble first attempt that I threw together. Surely I’m not the first person to try this?
class BayesDiagnostics:
    """Run a small set of automated pass/fail checks on MCMC output.

    The checks are deliberately coarse (divergence count, R-hat, effective
    sample size); they are a first filter, not a replacement for reviewing
    diagnostic plots by hand.
    """

    def __init__(self, idata: "az.InferenceData", verbose: bool = False):
        """Store the inference data to be checked.

        Args:
            idata: Inference data holding the posterior and sample stats.
            verbose: When true, each check prints its intermediate values.
        """
        self.idata = idata
        self.verbose = verbose
        # Filled by ``check_rhat``: worst (largest) R-hat per variable.
        self.rhats: Dict[str, float] = {}

    def check_trace(self):
        """Check that all chains look alike (not implemented).

        Idea: compute some kind of distance metric between chains.
        """
        raise NotImplementedError

    def count_divergences(self):
        """Return the total number of divergent transitions as an ``int``."""
        diverging = self.idata.sample_stats.diverging.to_numpy()
        # Cast so callers get a plain Python int rather than a NumPy scalar.
        total_divergences = int(diverging.flatten().sum())
        if self.verbose:
            print(f"Total divergences: {total_divergences}")
        return total_divergences

    def check_divergences(self):
        """Return ``True`` when sampling produced no divergences at all."""
        return self.count_divergences() == 0

    def check_rhat(self, threshold=1.01):
        """Return ``True`` when every R-hat is strictly below ``threshold``.

        Non-scalar variables are summarised by their largest element, so a
        single badly mixed component fails the whole variable. A NaN R-hat
        propagates through the maximum and fails the check. The per-variable
        values are stored in ``self.rhats``.
        """
        rhats = az.rhat(self.idata)
        worst: Dict[str, float] = {}
        for var_name in rhats.data_vars:
            values = rhats[var_name].values
            if values.size == 0:
                # Zero-sized variable: nothing to diagnose.
                continue
            worst[var_name] = float(values.max())
        self.rhats = worst
        if self.verbose:
            print(self.rhats)
        return all(rhat < threshold for rhat in self.rhats.values())

    def check_rank_uniformity(self, alpha=0.10):
        """Run uniformity tests on each chain's ranks (not implemented)."""
        raise NotImplementedError

    def check_autocorr(self, threshold=0.2, max_lag=20):
        """Check chain autocorrelations up to ``max_lag`` (not implemented)."""
        raise NotImplementedError

    def parse_ess(self, idata):
        """Return a DataFrame of effective sample sizes for ``idata``.

        The frame has columns ``"Parameter"`` and ``"ESS"``. For non-scalar
        variables ``"ESS"`` is the smallest value over all elements, i.e. the
        worst case for that variable.
        """
        ess_values = az.ess(idata)
        results = {}
        for param, ess in ess_values.data_vars.items():
            values = ess.values
            if values.size == 0:
                # Zero-sized variable: nothing to diagnose.
                continue
            results[param] = float(values.min())
        ess_df = pd.DataFrame(list(results.items()), columns=["Parameter", "ESS"])
        if self.verbose:
            print(ess_df)
        return ess_df

    def check_ess(self, min_ess=1000):
        """Return ``True`` when every variable's ESS exceeds ``min_ess``."""
        ess_df = self.parse_ess(self.idata)
        return bool((ess_df["ESS"] > min_ess).all())

    def check_energy(self, alpha=0.05):
        """Check the energy diagnostics (not implemented).

        NOTE(review): ArviZ's ``az.bfmi`` may be a starting point -- confirm.
        """
        raise NotImplementedError

    def run_all_diagnostics(self):
        """Run every implemented check and return a name -> passed mapping."""
        return {
            "zero_divergences": self.check_divergences(),
            "rhats_sufficiently_low": self.check_rhat(),
            "ess_sufficiently_high": self.check_ess(),
        }