Compute reduced summary statistics

Normally I compute summaries of traces with pm.summary(trace), but this can be quite slow for large problems. Is there an argument to specify that I only want the mean and standard deviation?

https://docs.pymc.io/api/stats.html?highlight=pymc3%20summary#pymc3.stats.summary specifies that I can pass a dictionary of functions via stat_funcs, but I don’t see a pm.stats.mean function, for example. Does PyMC3 use np.mean under the hood?

You can reduce the number of variables for which pm.summary() computes statistics by using the var_names parameter. As for computing only some of the statistics, I don’t think that is possible. Either way, to check model convergence you should compute and inspect rhat and ess at least once after sampling.
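A minimal sketch of both suggestions, assuming recent PyMC3 releases delegate pm.summary to ArviZ so the example calls az.summary directly. The InferenceData object here is hypothetical, built with az.from_dict from random draws, and the variable names mu, sigma and nuisance are illustrative only.

```python
import arviz as az
import numpy as np

# Hypothetical posterior: 4 chains x 1000 draws of three scalar variables.
rng = np.random.default_rng(42)
idata = az.from_dict(
    posterior={
        "mu": rng.normal(0.0, 1.0, size=(4, 1000)),
        "sigma": rng.gamma(2.0, 1.0, size=(4, 1000)),
        "nuisance": rng.normal(5.0, 2.0, size=(4, 1000)),
    }
)

# Restrict the summary to the variables of interest; "nuisance" is skipped.
summary = az.summary(idata, var_names=["mu", "sigma"])
print(summary)

# One-off convergence check after sampling: R-hat and (bulk) effective sample size.
rhat = az.rhat(idata, var_names=["mu", "sigma"])
ess = az.ess(idata, var_names=["mu", "sigma"])
for name in ["mu", "sigma"]:
    print(name, float(rhat[name]), float(ess[name]))
```

With real samples, replace the az.from_dict call with your own InferenceData; R-hat values close to 1 and a reasonably large ESS are what you would hope to see.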

Alternatively, you can write a simple function that computes just the mean and standard deviation with NumPy and returns a DataFrame.
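A minimal sketch of such a function, assuming the samples are available as a dict mapping each variable name to a NumPy array of shape (chain, draw, *shape). The helper name mean_std_summary and the "name[i, j]" row labels are hypothetical choices for illustration.

```python
import numpy as np
import pandas as pd


def mean_std_summary(samples):
    """Hypothetical helper: mean and sd per variable element.

    samples: dict mapping name -> array of shape (chain, draw, *shape).
    """
    rows = {}
    for name, values in samples.items():
        values = np.asarray(values)
        # Merge chain and draw into one sample axis; flatten any remaining dims.
        flat = values.reshape(values.shape[0] * values.shape[1], -1)
        means = flat.mean(axis=0)
        stds = flat.std(axis=0)  # ddof=0, same default as np.std
        if values.ndim == 2:
            # Scalar variable: a single row labelled with the plain name.
            rows[name] = (means[0], stds[0])
        else:
            # Vector/matrix variable: one row per element, in C order.
            for i, idx in enumerate(np.ndindex(values.shape[2:])):
                label = f"{name}[{', '.join(map(str, idx))}]"
                rows[label] = (means[i], stds[i])
    return pd.DataFrame.from_dict(rows, orient="index", columns=["mean", "sd"])
```

With an InferenceData object, the input dict could be built as `mean_std_summary({k: v.values for k, v in idata.posterior.data_vars.items()})`. Keep in mind this skips the diagnostics entirely, so it is only a substitute for the summary after convergence has been checked.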

It looks like you want something like the second example in the documentation, that is, passing a custom stat_funcs dictionary and using extend=False. Example code below:

import arviz as az
import numpy as np
stat_funcs = {"mean": np.mean, "std": np.std}
az.summary(idata, stat_funcs=stat_funcs, extend=False, kind="stats")

I’d also like to point out a few extra details to make the answer more complete.

The first comment is that PyMC3 now uses ArviZ for statistics and plotting, which is why I used idata instead of trace in the code above. ArviZ uses its own data structure, InferenceData, which aims to be backend-agnostic so that it can be used not only by PyMC3 users but also by PyStan, Pyro and others. If performance is critical, convert your trace to InferenceData once with az.from_pymc3; otherwise, the conversion is done internally every time you call an ArviZ function such as az.summary(trace, ...).

The second comment is that ArviZ applies the functions in stat_funcs through xarray.apply_ufunc, with a wrapper around each function. The wrapper lets users pass a wider range of functions (otherwise they would be limited to functions that follow NumPy ufunc conventions), but for functions that already are proper ufuncs it can be less efficient than a direct reduction such as idata.posterior.mean(dim=("chain", "draw")). By default, az.summary uses this direct approach; the wrapper is only needed to customize the output. That said, unless the model has a relatively large number of variables, this should not affect performance significantly.
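To illustrate the direct-reduction path mentioned above, here is a minimal sketch that computes mean and sd with native xarray reductions, bypassing the stat_funcs wrapper. It assumes the posterior is an xarray.Dataset with chain and draw dimensions (as idata.posterior is) and that every variable is scalar; xarray_mean_std is a hypothetical helper name.

```python
import numpy as np
import pandas as pd
import xarray as xr


def xarray_mean_std(posterior: xr.Dataset) -> pd.DataFrame:
    """Hypothetical helper: mean and sd of scalar variables via xarray reductions."""
    dims = ("chain", "draw")
    # Vectorised reductions over the sample dimensions; no per-function wrapper.
    means = posterior.mean(dim=dims)
    stds = posterior.std(dim=dims)  # ddof=0, like np.std
    names = list(posterior.data_vars)
    # float() only works on 0-d results, i.e. scalar variables (sketch assumption).
    return pd.DataFrame(
        {
            "mean": [float(means[n]) for n in names],
            "sd": [float(stds[n]) for n in names],
        },
        index=names,
    )


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical posterior laid out like idata.posterior: dims (chain, draw).
    posterior = xr.Dataset(
        {
            "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
            "sigma": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
        }
    )
    print(xarray_mean_std(posterior))
```

With real results this would be called as `xarray_mean_std(idata.posterior)`. Variables with extra dimensions would need their per-element results flattened into rows instead of the float() conversion used here.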

Finally, I used the kind argument, which at the time of writing is only available in the ArviZ development version, to avoid computing the diagnostics (ess, rhat, mcse). Without it the behaviour is somewhat buggy: even though the diagnostics are not returned when extend=False, they are still computed. The default statistics functions, unlike the diagnostics, are not executed when extend=False.

Thank you – that’s incredibly helpful. Marking as solution.