ArviZ KDE always results in FloatingPointError: underflow encountered in exp

I am working on a custom hit-and-run sampling algorithm for distributions with polytope support. I would love to use ArviZ to explore the results, but I run into the same underflow error for every result I try to analyse. My algorithm outputs an InferenceData object called infdat, and the first thing I would like to do is check the traces of the chains. I am starting with a small model that has one variable with six coordinates.

print(infdat.posterior)

gives

<xarray.Dataset>
Dimensions:   (chain: 4, draw: 5000, theta_id: 6)
Coordinates:
  * chain     (chain) int32 0 1 2 3
  * draw      (draw) int32 0 1 2 3 4 5 6 ... 4993 4994 4995 4996 4997 4998 4999
  * theta_id  (theta_id) <U6 'B_svd0' 'B_svd1' 'B_svd2' ... 'B_svd4' 'v2_xch'
Data variables:
    param     (chain, draw, theta_id) float64 0.0 0.0 0.0 ... 0.7464 1.29 0.4459
Attributes:
    created_at:     2023-07-06T09:44:13.653841
    arviz_version:  0.12.1
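For anyone trying to reproduce this, an `InferenceData` with the same layout can be built from a raw NumPy array with `az.from_dict`. A minimal sketch, assuming the sampler returns an array of shape (chain, draw, coordinate); the data are synthetic and the `theta_id` labels are hypothetical placeholders for the real ones:

```python
import numpy as np
import arviz as az

# Synthetic stand-in for the sampler output, shape (chain, draw, coordinate)
n_chains, n_draws, n_dim = 4, 5000, 6
rng = np.random.default_rng(42)
samples = rng.uniform(size=(n_chains, n_draws, n_dim))

# Hypothetical coordinate labels; the real ones come from the model
theta_ids = [f"theta_{i}" for i in range(n_dim)]

# chain and draw are added automatically as the two leading dimensions
infdat = az.from_dict(
    posterior={"param": samples},
    coords={"theta_id": theta_ids},
    dims={"param": ["theta_id"]},
)
print(infdat.posterior)
```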

When I now try to plot the traces with az.plot_trace(infdat), I get the following error:

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
Cell In [114], line 1
----> 1 az.plot_trace(
      2     infdatt.posterior, 
      3     plot_kwargs={'bw': 'scott'},
      4 )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\traceplot.py:260, in plot_trace(data, var_names, filter_vars, transform, coords, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, labeller, axes, backend, backend_config, backend_kwargs, show)
    257 backend = backend.lower()
    259 plot = get_plotting_function("plot_trace", "traceplot", backend)
--> 260 axes = plot(**trace_plot_args)
    262 return axes

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\traceplot.py:288, in plot_trace(data, var_names, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, labeller, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, plotters, divergence_data, axes, backend_kwargs, backend_config, show)
    286 aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
    287 aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
--> 288 ax = _plot_chains_mpl(
    289     ax,
    290     idy,
    291     value[..., sub_idx],
    292     data,
    293     chain_prop,
    294     combined,
    295     xt_labelsize,
    296     rug,
    297     kind,
    298     aux_trace_kwargs,
    299     hist_kwargs,
    300     aux_plot_kwargs,
    301     fill_kwargs,
    302     rug_kwargs,
    303     rank_kwargs,
    304     circular,
    305     circ_var_units,
    306     circ_units_trace,
    307 )
    308 if legend:
    309     handles.append(
    310         Line2D(
    311             [],
   (...)
    315         )
    316     )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\traceplot.py:488, in _plot_chains_mpl(axes, idy, value, data, chain_prop, combined, xt_labelsize, rug, kind, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, rank_kwargs, circular, circ_var_units, circ_units_trace)
    486         aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
    487         if not idy:
--> 488             axes = plot_dist(
    489                 values=row,
    490                 textsize=xt_labelsize,
    491                 rug=rug,
    492                 ax=axes,
    493                 hist_kwargs=hist_kwargs,
    494                 plot_kwargs=aux_kwargs,
    495                 fill_kwargs=fill_kwargs,
    496                 rug_kwargs=rug_kwargs,
    497                 backend="matplotlib",
    498                 show=False,
    499                 is_circular=circ_var_units,
    500             )
    502 if kind == "rank_bars" and idy:
    503     axes = plot_rank(data=value, kind="bars", ax=axes, **rank_kwargs)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\distplot.py:233, in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, is_circular, ax, backend, backend_kwargs, show, **kwargs)
    230 backend = backend.lower()
    232 plot = get_plotting_function("plot_dist", "distplot", backend)
--> 233 ax = plot(**dist_plot_args)
    234 return ax

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:92, in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, is_circular, ax, backend_kwargs, show)
     89     plot_kwargs.setdefault("color", color)
     90     legend = label is not None
---> 92     ax = plot_kde(
     93         values,
     94         values2,
     95         cumulative=cumulative,
     96         rug=rug,
     97         label=label,
     98         bw=bw,
     99         quantiles=quantiles,
    100         rotated=rotated,
    101         contour=contour,
    102         legend=legend,
    103         fill_last=fill_last,
    104         textsize=textsize,
    105         plot_kwargs=plot_kwargs,
    106         fill_kwargs=fill_kwargs,
    107         rug_kwargs=rug_kwargs,
    108         contour_kwargs=contour_kwargs,
    109         contourf_kwargs=contourf_kwargs,
    110         pcolormesh_kwargs=pcolormesh_kwargs,
    111         ax=ax,
    112         backend="matplotlib",
    113         backend_kwargs=backend_kwargs,
    114         is_circular=is_circular,
    115         show=show,
    116     )
    118 if backend_show(show):
    119     plt.show()

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\kdeplot.py:270, in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, quantiles, rotated, contour, hdi_probs, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
    267     else:
    268         bw = "experimental"
--> 270 grid, density = kde(values, is_circular, bw=bw, adaptive=adaptive, cumulative=cumulative)
    271 lower, upper = grid[0], grid[-1]
    273 if cumulative:

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:502, in kde(x, circular, **kwargs)
    499 else:
    500     kde_fun = _kde_linear
--> 502 return kde_fun(x, **kwargs)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:588, in _kde_linear(x, bw, adaptive, extend, bound_correction, extend_fct, bw_fct, bw_return, custom_lims, cumulative, grid_len, **kwargs)
    585 grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
    587 # Bandwidth estimation
--> 588 bw = bw_fct * _get_bw(x, bw, grid_counts, x_std, x_range)
    590 # Density estimation
    591 if adaptive:

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:157, in _get_bw(x, bw, grid_counts, x_std, x_range)
    150         raise ValueError(
    151             "Unrecognized bandwidth method.\n"
    152             f"Input is: {bw_lower}.\n"
    153             f"Expected one of: {list(_BW_METHODS_LINEAR)}."
    154         )
    156     bw_fun = _BW_METHODS_LINEAR[bw_lower]
--> 157     bw = bw_fun(x, grid_counts=grid_counts, x_std=x_std, x_range=x_range)
    158 else:
    159     raise ValueError(
    160         "Unrecognized `bw` argument.\n"
    161         "Expected a positive numeric or one of the following strings:\n"
    162         f"{list(_BW_METHODS_LINEAR)}."
    163     )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:81, in _bw_experimental(x, grid_counts, x_std, x_range)
     79 """Experimental bandwidth estimator."""
     80 bw_silverman = _bw_silverman(x, x_std=x_std)
---> 81 bw_isj = _bw_isj(x, grid_counts=grid_counts, x_range=x_range)
     82 return 0.5 * (bw_silverman + bw_isj)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:73, in _bw_isj(x, grid_counts, x_std, x_range)
     70 k_sq = np.arange(1, grid_len) ** 2
     71 a_sq = a_k[range(1, grid_len)] ** 2
---> 73 t = _root(_fixed_point, x_len, args=(x_len, k_sq, a_sq), x=x)
     74 h = t**0.5 * x_range
     75 return h

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:261, in _root(function, N, args, x)
    259 while not found:
    260     try:
--> 261         bw, res = brentq(function, 0, 0.01, args=args, full_output=True, disp=False)
    262         found = res.converged
    263     except ValueError:

File ~\AppData\Roaming\Python\Python38\site-packages\scipy\optimize\_zeros_py.py:783, in brentq(f, a, b, args, xtol, rtol, maxiter, full_output, disp)
    781 if rtol < _rtol:
    782     raise ValueError("rtol too small (%g < %g)" % (rtol, _rtol))
--> 783 r = _zeros._brentq(f, a, b, xtol, rtol, maxiter, args, full_output, disp)
    784 return results_c(full_output, r)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:238, in _fixed_point(t, N, k_sq, a_sq)
    235 a_sq = np.asfarray(a_sq, dtype=np.float64)
    237 l = 7
--> 238 f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
    239 f *= 0.5 * np.pi ** (2.0 * l)
    241 for j in np.arange(l - 1, 2 - 1, -1):

FloatingPointError: underflow encountered in exp
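The last frame is a plain `np.exp` call inside ArviZ's ISJ bandwidth estimator, which `plot_trace` reaches because the default `"experimental"` bandwidth averages Silverman's rule and ISJ. NumPy ignores floating-point underflow by default, so this exception only appears if the error mode has been switched to `'raise'`, for example by `np.seterr(all='raise')` somewhere in the session or in the sampler code. Whether that happens here is an assumption worth checking; the sketch below only shows how the two modes differ on a single `np.exp` call (the helper name is hypothetical):

```python
import numpy as np

def exp_underflow_raises(mode: str) -> bool:
    """Hypothetical helper: True if np.exp(-1000.0) raises FloatingPointError
    when NumPy's underflow handling is set to `mode`."""
    with np.errstate(under=mode):
        try:
            np.exp(np.float64(-1000.0))  # true value is far below the float64 range
        except FloatingPointError:
            return True
    return False

print(np.geterr())                     # NumPy's default for 'under' is 'ignore'
print(exp_underflow_raises("ignore"))  # underflow silently gives 0.0
print(exp_underflow_raises("raise"))   # same error as the last frame above
```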

Questions:

  1. How can I plot traces without the KDE plots on the side?
  2. Alternatively, how can I get the KDE to work?
  3. Finally, can I just use histograms for these variables instead of KDE? I tried instructing ArviZ to use histograms in plot_trace via the plot_kwargs argument, but this failed and again gave the above error.

Some observations that might help in thinking about this:

  1. az.plot_forest(infdat) works.
  2. az.plot_density(infdat, bw='scott') works, which makes me think I should be able to instruct plot_trace to use 'scott' for the KDE (passing it via plot_kwargs fails with exactly the same error as above).
  3. The KDE fails even for chains with only 4 draws. These draws are still in the warmup phase, which makes it unlikely that they come from a very tight region of the posterior.
  4. The KDE fails for every single variable individually, which is odd since some variables are numbers between 0 and 1 while others lie together in a polytope.
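Observation 2 points at the bandwidth rule. According to the traceback, `plot_trace` falls back to the `"experimental"` bandwidth (the mean of Silverman's rule and ISJ), while `bw='scott'` never enters the ISJ root-finding. A sketch that calls `az.kde` directly with both rules, which can help isolate a failing coordinate outside the plotting code; `values` is synthetic here and would be replaced by a 1-D slice of the real posterior, e.g. one `theta_id` of one chain:

```python
import numpy as np
import arviz as az

# Synthetic stand-in for one coordinate of one chain
rng = np.random.default_rng(0)
values = rng.uniform(0.0, 1.0, size=5000)

# Default linear KDE: bw="experimental" (Silverman + ISJ), the path in the traceback
grid_exp, pdf_exp = az.kde(values)

# Scott's rule, the bandwidth that works in az.plot_density(infdat, bw="scott")
grid_scott, pdf_scott = az.kde(values, bw="scott")

print(grid_exp.shape, pdf_exp.shape, grid_scott.shape, pdf_scott.shape)
```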

I have now noticed that I also get underflow errors for az.plot_ppc(infdat, data_pairs={"observed_data": "simulated_data"}). Is there a way to get histogram plots for continuous variables? I think that would have fewer numerical issues than using KDE everywhere.


Welcome!

I suspect that bw="scott" should be passed in hist_kwargs rather than plot_kwargs.

The default is to throw out warmup/tuning samples, so if you request “4 samples” (e.g., `idata = pm.sample(4)`), you are most likely getting 4 draws that occur after the (default 1000) tuning samples.
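Since the sampler in the question is custom rather than `pm.sample`, whether warmup draws end up in `posterior` depends on how `infdat` is built. If the raw output still contains them, a minimal sketch (assuming the first `n_warmup` draws of each chain are warmup; all sizes are illustrative) stores them in the separate `warmup_posterior` group that InferenceData provides for this purpose:

```python
import numpy as np
import arviz as az

# Synthetic raw output: the first n_warmup draws of each chain are warmup
n_warmup, n_chains, n_draws, n_dim = 1000, 4, 5000, 6
rng = np.random.default_rng(1)
raw = rng.normal(size=(n_chains, n_warmup + n_draws, n_dim))

idata = az.from_dict(
    posterior={"param": raw[:, n_warmup:]},         # retained draws only
    warmup_posterior={"param": raw[:, :n_warmup]},  # warmup kept apart
    save_warmup=True,                               # otherwise warmup groups are dropped
)
print(idata.groups())
```

Plotting functions such as `plot_trace` then only see the retained draws in `posterior`.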

The other obvious way to investigate would be to dive into the inference data itself. Once you figure out what’s going on, you may be able to get things working with ArviZ. @OriolAbril suggested this as a potential shortcut to get just the trace (and not the KDE):

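import matplotlib.pyplot as plt
import numpy as np
from arviz.sel_utils import xarray_var_iter

# one entry per scalar component; combined=True keeps the chain dimension in var_data
plotters = list(xarray_var_iter(infdat.posterior, combined=True, dim_order=["draw", "chain"]))
n_plots = len(plotters)
n_cols = 3
fig, axes = plt.subplots(int(np.ceil(n_plots / n_cols)), n_cols, squeeze=False)  # or variations
for ax, (var_name, sel, isel, var_data) in zip(axes.ravel(), plotters):
    ax.plot(var_data, lw=1, alpha=.5)  # var_data has shape (draw, chain): one line per chain
    ax.set_title(f"{var_name} {sel}")
plt.show()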

Or maybe:

import matplotlib.pyplot as plt

infdat.posterior.to_stacked_array("var", sample_dims=["chain", "draw"]).plot.line(hue="chain", col="var", col_wrap=4, sharey=False)
plt.show()

But I’m just parroting what @OriolAbril is feeding me.

Another option is to set az.rcParams["plot.density_kind"] = "hist", which might be a bit more robust: if a chain is stuck, for example, it is not possible to compute the KDE. I think this answers the last question about using histograms as the default. You can also set this in your arvizrc file. All available parameters are listed in arvizrc.template in the arviz-devs/arviz repository on GitHub, and some documentation on them is in the blog post "ArviZ customization with rcParams" on Oriol unraveled.
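A minimal sketch of both ways of applying this setting, on synthetic data: `az.rc_context` limits it to a `with` block, while assigning to `az.rcParams` changes it for the rest of the session. The Agg backend is only there so the sketch runs without a display.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import arviz as az

# Synthetic InferenceData with one vector-valued variable
rng = np.random.default_rng(2)
idata = az.from_dict(posterior={"param": rng.normal(size=(4, 500, 3))})

# Scoped: histograms instead of KDEs only inside this block
with az.rc_context(rc={"plot.density_kind": "hist"}):
    axes = az.plot_trace(idata)

# Global: histograms for every later plot in this session
az.rcParams["plot.density_kind"] = "hist"
```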

There is currently no way to pass the bw argument to the KDE via plot_trace, but a PR would be welcome.
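Until then, `az.plot_kde` itself does accept `bw` (it appears in its signature in the traceback), so a trace-style figure with a chosen bandwidth can be assembled by hand. A sketch under that assumption; `trace_with_bw` is a hypothetical helper, not part of ArviZ, and the demo `DataArray` is synthetic. For the model in the question it would be called as `trace_with_bw(infdat.posterior["param"], "theta_id")`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import xarray as xr
import arviz as az

def trace_with_bw(da: xr.DataArray, dim: str, bw="scott"):
    """Hypothetical helper: one row per coordinate of `dim`, with a KDE using
    the chosen bandwidth on the left and the per-chain traces on the right."""
    coords = da[dim].values
    fig, axes = plt.subplots(len(coords), 2, squeeze=False, figsize=(10, 2 * len(coords)))
    for (ax_kde, ax_trace), coord in zip(axes, coords):
        sub = da.sel({dim: coord}).transpose("chain", "draw")
        for i, chain_vals in enumerate(sub.values):
            color = f"C{i}"  # same color for a chain's KDE and trace
            az.plot_kde(chain_vals, bw=bw, ax=ax_kde, plot_kwargs={"color": color})
            ax_trace.plot(chain_vals, lw=1, alpha=0.5, color=color)
        ax_kde.set_title(f"{da.name} {coord}")
    fig.tight_layout()
    return fig, axes

# Synthetic demo data with the same (chain, draw, theta_id) layout
rng = np.random.default_rng(3)
da = xr.DataArray(
    rng.normal(size=(4, 500, 2)),
    dims=("chain", "draw", "theta_id"),
    coords={"theta_id": ["a", "b"]},
    name="param",
)
fig, axes = trace_with_bw(da, "theta_id", bw="scott")
```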

Thank you so much! This command az.rcParams["plot.density_kind"] = "hist" helped and resulted in the plot below:

It seems that the chains are not stuck, which means that there is something else going on in the KDE estimation.
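One possibility not tested in this thread: NumPy ignores underflow by default, so this error only surfaces if the floating-point error mode was switched to `'raise'` (for example with `np.seterr(all='raise')`); printing `np.geterr()` right before plotting would show whether that is the case. If so, a minimal workaround is to restore the default just around the plotting call. The wrapper name is hypothetical and the data synthetic:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import arviz as az

def plot_trace_ignoring_underflow(data, **kwargs):
    """Hypothetical wrapper: run az.plot_trace with NumPy's default underflow
    handling ('ignore'); errstate restores the caller's settings on exit."""
    with np.errstate(under="ignore"):
        return az.plot_trace(data, **kwargs)

rng = np.random.default_rng(4)
idata = az.from_dict(posterior={"param": rng.normal(size=(4, 500, 3))})
axes = plot_trace_ignoring_underflow(idata)
```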