Automated Diagnostics?

Carefully reviewing all diagnostic plots works for prototype models, but what if we are deploying hundreds of them into production? Has anyone made an effort to automate diagnostics in an approximate, incomplete way? I say incomplete because I have a strong prior that no automation will be able to replace careful manual review by an expert.

Here is the humble beginning I threw together. Surely I'm not the first person to try this?

from typing import Dict

import arviz as az
import pandas as pd


class BayesDiagnostics:

    def __init__(self, idata: az.InferenceData, verbose: bool = False):
        self.idata = idata
        self.verbose = verbose
        self.rhats: Dict[str, float] = {}

    def check_trace(self):
        """Would be nice to check that all the chains are kind of the
        same. Could compute some kind of distance metric."""

        raise NotImplementedError

    def count_divergences(self):

        total_divergences = self.idata.sample_stats.diverging.to_numpy().flatten().sum()

        if self.verbose:
            print(f"Total divergences: {total_divergences}")

        return total_divergences

    def check_divergences(self):

        total_divergences = self.count_divergences()

        return total_divergences == 0

    def check_rhat(self, threshold=1.01):
        rhats = az.rhat(self.idata)
        # keep the worst (largest) r-hat per variable so vector-valued parameters work too
        self.rhats = {
            var_name: float(rhats[var_name].values.max())
            for var_name in rhats.data_vars
        }
        if self.verbose:
            print(self.rhats)

        return all(rhat < threshold for rhat in self.rhats.values())

    def check_rank_uniformity(self, alpha=0.10):
        """Would be nice to run uniformity tests on each chain, but requires manually
        parsing the chains and doing lots of work. Ignore for now."""

        raise NotImplementedError

    def check_autocorr(self, threshold=0.2, max_lag=20):
        """Would be nice to compute autocorrelations, but requires manually
        parsing the chains to numpy. Ignore for now."""

        raise NotImplementedError

    def parse_ess(self, idata):

        ess_values = az.ess(idata)

        results = {}
        for param, ess in ess_values.items():
            assert (
                ess.size == 1
            ), f"ESS for parameter '{param}' is not scalar (size = {ess.size})"
            results[param] = ess.to_numpy().item()

        ess_df = pd.DataFrame(list(results.items()), columns=["Parameter", "ESS"])

        if self.verbose:
            print(ess_df)

        return ess_df

    def check_ess(self, min_ess=1000):

        return (self.parse_ess(self.idata)["ESS"] > min_ess).all()

    def check_energy(self, alpha=0.05):
        """Would be nice to check this. Don't know how."""

        raise NotImplementedError

    def run_all_diagnostics(self):
        """And there were a lot fewer automated checks than I first
        envisioned making...
        """

        diagnostics = {
            "zero_divergences": self.check_divergences(),
            "rhats_sufficiently_low": self.check_rhat(),
            "ess_sufficiently_high": self.check_ess(),
        }

        return diagnostics
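
A quick usage sketch (assuming idata is an az.InferenceData from some fitted model):

diag = BayesDiagnostics(idata, verbose=True)
results = diag.run_all_diagnostics()
failed = [name for name, passed in results.items() if not passed]
if failed:
    print(f"Manual review needed, failed checks: {failed}")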

I like this idea and it is definitely the sort of thing that would be of wide use. That being said, I suspect that it may be easier to do some pandas operations (averaging, summing, min()ing, etc.) on the dataframe returned by az.summary(). You may still have to grab the divergences separately, but would that get you what you need otherwise?

Oh duh, az.summary() does have a lot of this already! I could really simplify this :smiley:
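
Something like this, maybe (a rough sketch; I'm assuming the default az.summary() columns r_hat, ess_bulk, and ess_tail, and the divergences still have to come from sample_stats):

import arviz as az

def quick_checks(idata, rhat_threshold=1.01, min_ess=1000):
    """Boil az.summary() down to a few pass/fail flags."""
    summ = az.summary(idata)  # one row per parameter, columns include r_hat, ess_bulk, ess_tail
    divergences = int(idata.sample_stats.diverging.sum())
    return {
        "zero_divergences": divergences == 0,
        "rhats_sufficiently_low": bool((summ["r_hat"] < rhat_threshold).all()),
        "ess_sufficiently_high": bool(summ[["ess_bulk", "ess_tail"]].min().min() > min_ess),
    }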

But it does not have everything. Consider these checks that are currently only accomplished visually:

  1. Consistent trace plots showing ‘k’ traces mostly shaped the same.
  2. Uniformity of rank plots.
  3. Autocorrelation plots…“look” reasonable. Not even sure how I make this judgement but I definitely know I should be looking at them.

These things seem to require more data, and more complex logic, than checking az.summary().
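
To give a sense of what I mean by more complex logic, an autocorrelation check for item 3 might look roughly like this (a sketch; the 0.2 threshold at lag 20 is the same arbitrary choice as in my check_autocorr stub above):

import numpy as np

def check_autocorr(idata, threshold=0.2, max_lag=20):
    """Fail if any chain of any parameter still has |autocorrelation| > threshold at max_lag."""
    for var in idata.posterior.data_vars:
        arr = idata.posterior[var].values                  # (chain, draw, *param_dims)
        arr = arr.reshape(arr.shape[0], arr.shape[1], -1)  # flatten extra parameter dims
        for k in range(arr.shape[-1]):
            for chain in arr[..., k]:
                x = chain - chain.mean()
                # crude normalized autocorrelation at a single lag
                rho = np.dot(x[:-max_lag], x[max_lag:]) / np.dot(x, x)
                if abs(rho) > threshold:
                    return False
    return True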

It’s not clear what “mostly shaped the same” means, but it seems like it would be captured by the r-hats, at least partially? And the autocorrelation is reflected in the ESS. But I am less clear about the rank plot.

Yeah, these are good points. Maybe the only thing not already covered by the easy az.summary() metrics is the rank uniformity.

The simple check for that would be a null hypothesis rejection, but if we’re running at scale, that kind of thing would be overwhelmed by false positives.
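
For concreteness, that simple check could be a KS test of each chain's normalized ranks against a uniform, something like this (a sketch assuming a scalar parameter; running it over hundreds of models and parameters is exactly where the false-positive problem bites):

import numpy as np
from scipy import stats

def check_rank_uniformity(idata, var_name, alpha=0.10):
    """KS-test each chain's ranks (computed over the pooled draws) against a uniform."""
    draws = idata.posterior[var_name].values          # (chain, draw), scalar parameter assumed
    ranks = stats.rankdata(draws).reshape(draws.shape)
    u = (ranks - 0.5) / ranks.size                    # normalize ranks into (0, 1)
    p_values = [stats.kstest(chain_u, "uniform").pvalue for chain_u in u]
    # reject uniformity for a chain if its p-value falls below alpha
    return all(p > alpha for p in p_values)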

Second-order Bayesian inference needed :D?

I think this is a great idea, but I agree. Those diagnostics are plots because they can't be boiled down to a single statistic. A lot of them aren't just binary good/bad indicators either; they get at what the specific problem might be.

You could try a two-stage approach. Automatically track the top-line metrics, like divergences, ESS, and r-hats, and send a summary or alert when a model is out of bounds. I'd also save the trace plots and other important graphical indicators associated with each run. It's also nice to have some plots describing the incoming data, because often the model doesn't fit right because the data was weird. You won't look at these most of the time, but you'll have them available to diagnose a model with divergences, bad ESS, or bad r-hats.
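
A minimal sketch of the save-the-plots part, with a made-up run_id and output directory (az.plot_trace and az.plot_rank draw onto matplotlib figures, so saving them is straightforward):

import arviz as az
import matplotlib.pyplot as plt
from pathlib import Path

def archive_run(idata, run_id, out_dir="diagnostic_plots"):
    """Save trace and rank plots for a run so they exist if the metrics look bad later."""
    out = Path(out_dir) / str(run_id)
    out.mkdir(parents=True, exist_ok=True)
    for plot_fn, name in [(az.plot_trace, "trace"), (az.plot_rank, "rank")]:
        plot_fn(idata)
        plt.gcf().savefig(out / f"{name}.png", dpi=150)
        plt.close("all")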
