ArviZ KDE always results in FloatingPointError: underflow encountered in exp

I am working on a custom hit-and-run sampling algorithm for distributions with polytope support. I would love to use ArviZ to explore the results, but I run into the same underflow error for every result I try to analyse. My algorithm outputs an InferenceData object called infdat, and the first thing I would like to do is check the traces of the chains. I am starting with a small model that has one variable with six coordinates.

print(infdat.posterior)

gives

<xarray.Dataset>
Dimensions:   (chain: 4, draw: 5000, theta_id: 6)
Coordinates:
  * chain     (chain) int32 0 1 2 3
  * draw      (draw) int32 0 1 2 3 4 5 6 ... 4993 4994 4995 4996 4997 4998 4999
  * theta_id  (theta_id) <U6 'B_svd0' 'B_svd1' 'B_svd2' ... 'B_svd4' 'v2_xch'
Data variables:
    param     (chain, draw, theta_id) float64 0.0 0.0 0.0 ... 0.7464 1.29 0.4459
Attributes:
    created_at:     2023-07-06T09:44:13.653841
    arviz_version:  0.12.1
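To reproduce the problem without the custom sampler, here is a minimal sketch (assuming only NumPy and ArviZ) that builds an InferenceData object with the same layout as the one printed above: four chains, 5000 draws, and a `param` variable with a six-element `theta_id` dimension. The draws are synthetic standard-normal values and the coordinate labels are hypothetical placeholders, not the actual sampler output.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
n_chain, n_draw = 4, 5000
# Hypothetical coordinate labels; the real ones come from the custom sampler.
theta_ids = [f"theta{i}" for i in range(6)]

# Synthetic draws with shape (chain, draw, theta_id), matching the printed Dataset.
draws = rng.normal(size=(n_chain, n_draw, len(theta_ids)))

infdat = az.from_dict(
    posterior={"param": draws},
    coords={"theta_id": theta_ids},
    dims={"param": ["theta_id"]},
)
print(infdat.posterior)
```

Calling az.plot_trace on this synthetic object is a quick way to check whether the error depends on the sampler output or on the Python environment.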

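When I now try to plot the traces with az.plot_trace(infdat), I get the following error (the call in this traceback also passes plot_kwargs={'bw': 'scott'}, which fails in exactly the same way):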

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
Cell In [114], line 1
----> 1 az.plot_trace(
      2     infdatt.posterior, 
      3     plot_kwargs={'bw': 'scott'},
      4 )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\traceplot.py:260, in plot_trace(data, var_names, filter_vars, transform, coords, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, labeller, axes, backend, backend_config, backend_kwargs, show)
    257 backend = backend.lower()
    259 plot = get_plotting_function("plot_trace", "traceplot", backend)
--> 260 axes = plot(**trace_plot_args)
    262 return axes

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\traceplot.py:288, in plot_trace(data, var_names, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, labeller, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, plotters, divergence_data, axes, backend_kwargs, backend_config, show)
    286 aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
    287 aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
--> 288 ax = _plot_chains_mpl(
    289     ax,
    290     idy,
    291     value[..., sub_idx],
    292     data,
    293     chain_prop,
    294     combined,
    295     xt_labelsize,
    296     rug,
    297     kind,
    298     aux_trace_kwargs,
    299     hist_kwargs,
    300     aux_plot_kwargs,
    301     fill_kwargs,
    302     rug_kwargs,
    303     rank_kwargs,
    304     circular,
    305     circ_var_units,
    306     circ_units_trace,
    307 )
    308 if legend:
    309     handles.append(
    310         Line2D(
    311             [],
   (...)
    315         )
    316     )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\traceplot.py:488, in _plot_chains_mpl(axes, idy, value, data, chain_prop, combined, xt_labelsize, rug, kind, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, rank_kwargs, circular, circ_var_units, circ_units_trace)
    486         aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
    487         if not idy:
--> 488             axes = plot_dist(
    489                 values=row,
    490                 textsize=xt_labelsize,
    491                 rug=rug,
    492                 ax=axes,
    493                 hist_kwargs=hist_kwargs,
    494                 plot_kwargs=aux_kwargs,
    495                 fill_kwargs=fill_kwargs,
    496                 rug_kwargs=rug_kwargs,
    497                 backend="matplotlib",
    498                 show=False,
    499                 is_circular=circ_var_units,
    500             )
    502 if kind == "rank_bars" and idy:
    503     axes = plot_rank(data=value, kind="bars", ax=axes, **rank_kwargs)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\distplot.py:233, in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, is_circular, ax, backend, backend_kwargs, show, **kwargs)
    230 backend = backend.lower()
    232 plot = get_plotting_function("plot_dist", "distplot", backend)
--> 233 ax = plot(**dist_plot_args)
    234 return ax

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:92, in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, is_circular, ax, backend_kwargs, show)
     89     plot_kwargs.setdefault("color", color)
     90     legend = label is not None
---> 92     ax = plot_kde(
     93         values,
     94         values2,
     95         cumulative=cumulative,
     96         rug=rug,
     97         label=label,
     98         bw=bw,
     99         quantiles=quantiles,
    100         rotated=rotated,
    101         contour=contour,
    102         legend=legend,
    103         fill_last=fill_last,
    104         textsize=textsize,
    105         plot_kwargs=plot_kwargs,
    106         fill_kwargs=fill_kwargs,
    107         rug_kwargs=rug_kwargs,
    108         contour_kwargs=contour_kwargs,
    109         contourf_kwargs=contourf_kwargs,
    110         pcolormesh_kwargs=pcolormesh_kwargs,
    111         ax=ax,
    112         backend="matplotlib",
    113         backend_kwargs=backend_kwargs,
    114         is_circular=is_circular,
    115         show=show,
    116     )
    118 if backend_show(show):
    119     plt.show()

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\plots\kdeplot.py:270, in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, quantiles, rotated, contour, hdi_probs, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
    267     else:
    268         bw = "experimental"
--> 270 grid, density = kde(values, is_circular, bw=bw, adaptive=adaptive, cumulative=cumulative)
    271 lower, upper = grid[0], grid[-1]
    273 if cumulative:

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:502, in kde(x, circular, **kwargs)
    499 else:
    500     kde_fun = _kde_linear
--> 502 return kde_fun(x, **kwargs)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:588, in _kde_linear(x, bw, adaptive, extend, bound_correction, extend_fct, bw_fct, bw_return, custom_lims, cumulative, grid_len, **kwargs)
    585 grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
    587 # Bandwidth estimation
--> 588 bw = bw_fct * _get_bw(x, bw, grid_counts, x_std, x_range)
    590 # Density estimation
    591 if adaptive:

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:157, in _get_bw(x, bw, grid_counts, x_std, x_range)
    150         raise ValueError(
    151             "Unrecognized bandwidth method.\n"
    152             f"Input is: {bw_lower}.\n"
    153             f"Expected one of: {list(_BW_METHODS_LINEAR)}."
    154         )
    156     bw_fun = _BW_METHODS_LINEAR[bw_lower]
--> 157     bw = bw_fun(x, grid_counts=grid_counts, x_std=x_std, x_range=x_range)
    158 else:
    159     raise ValueError(
    160         "Unrecognized `bw` argument.\n"
    161         "Expected a positive numeric or one of the following strings:\n"
    162         f"{list(_BW_METHODS_LINEAR)}."
    163     )

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:81, in _bw_experimental(x, grid_counts, x_std, x_range)
     79 """Experimental bandwidth estimator."""
     80 bw_silverman = _bw_silverman(x, x_std=x_std)
---> 81 bw_isj = _bw_isj(x, grid_counts=grid_counts, x_range=x_range)
     82 return 0.5 * (bw_silverman + bw_isj)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:73, in _bw_isj(x, grid_counts, x_std, x_range)
     70 k_sq = np.arange(1, grid_len) ** 2
     71 a_sq = a_k[range(1, grid_len)] ** 2
---> 73 t = _root(_fixed_point, x_len, args=(x_len, k_sq, a_sq), x=x)
     74 h = t**0.5 * x_range
     75 return h

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:261, in _root(function, N, args, x)
    259 while not found:
    260     try:
--> 261         bw, res = brentq(function, 0, 0.01, args=args, full_output=True, disp=False)
    262         found = res.converged
    263     except ValueError:

File ~\AppData\Roaming\Python\Python38\site-packages\scipy\optimize\_zeros_py.py:783, in brentq(f, a, b, args, xtol, rtol, maxiter, full_output, disp)
    781 if rtol < _rtol:
    782     raise ValueError("rtol too small (%g < %g)" % (rtol, _rtol))
--> 783 r = _zeros._brentq(f, a, b, xtol, rtol, maxiter, args, full_output, disp)
    784 return results_c(full_output, r)

File C:\Miniconda3\envs\sumo\lib\site-packages\arviz\stats\density_utils.py:238, in _fixed_point(t, N, k_sq, a_sq)
    235 a_sq = np.asfarray(a_sq, dtype=np.float64)
    237 l = 7
--> 238 f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
    239 f *= 0.5 * np.pi ** (2.0 * l)
    241 for j in np.arange(l - 1, 2 - 1, -1):

FloatingPointError: underflow encountered in exp

Questions:
  1. How can I plot traces without the KDE plots on the side?
  2. Alternatively, how can I get the KDE to work?
  3. Can I just use histograms for these variables instead of KDE? I tried instructing ArviZ to use histograms in plot_trace via the plot_kwargs argument, but this failed and again gave the above error.

Some observations I have made that might help thinking about this:

  1. az.plot_forest(infdat) works
  2. az.plot_density(infdat, bw='scott') works, which makes me think that I should be able to instruct plot_trace to use 'scott' for the KDE (passing it through plot_kwargs fails and gives me exactly the same error as above)
  3. Even when the chains contain only 4 draws, which are still in the warmup phase, the KDE fails, so it is unlikely that the values come from a very tight region of the posterior.
  4. The KDE fails for every single variable individually, which is odd, since some variables are numbers between 0 and 1 and others lie together in a polytope

I have now noticed that I also get underflow errors for az.plot_ppc(infdat, data_pairs={"observed_data": "simulated_data"}). Can I somehow get histogram plots for continuous variables? I think that would have fewer numerical issues than using KDE everywhere.


Welcome!

I suspect that bw="scott" should be passed in as hist_kwargs rather than plot_kwargs.

The default is to throw out warmup/tuning samples, so if you request "4 samples" (e.g., `idata = pm.sample(4)`), you are most likely getting 4 draws that occur after the (default 1000) tuning samples.

The other obvious way to investigate would be to dive into the inference data itself. Once you figure out what's going on, you may be able to get things working with ArviZ. @OriolAbril suggested this as a potential shortcut to get just the trace (and not the KDE):

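import math
import arviz as az
import matplotlib.pyplot as plt

# `data` is the InferenceData object (infdat in the question)
plotters = list(az.sel_utils.xarray_var_iter(data.posterior, combined=True, dim_order=["draw", "chain"]))
n_plots = len(plotters)
fig, axes = plt.subplots(math.ceil(n_plots / 3), 3, squeeze=False)  # 3 columns, enough rows for every coordinate
for ax, (var_name, sel, isel, var_data) in zip(axes.ravel(), plotters):
    ax.plot(var_data, lw=1, alpha=.5)  # var_data is (draw, chain): one line per chain
    ax.set_title(f"{var_name} {sel}")  # include the coordinate so the panels are distinguishable
plt.show()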

Or maybe:

data.posterior.to_stacked_array("var", sample_dims=["chain", "draw"]).plot.line(hue="chain", col="var", col_wrap=4, sharey=False)
plt.show()

But I'm just parroting what @OriolAbril is feeding me.

Another option is to set az.rcParams["plot.density_kind"] = "hist", which might be a bit more robust. For example, if a chain is stuck, it is not possible to compute the KDE. I think this answers the last question about using histograms by default. You can also set that in your arvizrc file; all available parameters can be found at https://github.com/arviz-devs/arviz/blob/main/arvizrc.template, and there is some documentation on them at https://oriolabrilpla.cat/en/blog/posts/2020/rcParams.html
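A minimal sketch of the rcParams route, assuming ArviZ with the Matplotlib backend. `az.rc_context` limits the change to the plots created inside one block instead of changing the session default; the synthetic `idata` is only a stand-in for `infdat`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import arviz as az

rng = np.random.default_rng(1)
# Stand-in for infdat: 4 chains x 500 draws of a 6-element variable
idata = az.from_dict(
    posterior={"param": rng.normal(size=(4, 500, 6))},
    dims={"param": ["theta_id"]},
)

# Option 1: change the default for the whole session
# az.rcParams["plot.density_kind"] = "hist"

# Option 2: use histograms only for the plots created inside this block
with az.rc_context(rc={"plot.density_kind": "hist"}):
    axes = az.plot_trace(idata)  # density column drawn as histograms, no KDE
```

With the default compact layout, the single variable gets one row: histograms on the left, traces on the right.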

There is currently no way to pass the bw argument to the KDE via plot_trace, but a PR would be welcome.
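Until plot_trace can forward bw, one workaround is to build the figure by hand and call az.kde with bw="scott" directly, the same bandwidth rule that made az.plot_density work. Scott's rule is a closed-form estimate, whereas the default "experimental" bandwidth in the traceback averages Silverman's rule with ISJ, which needs a brentq root search. A minimal sketch, assuming the draws are available as a NumPy array of shape (chain, draw, parameter), such as infdat.posterior["param"].values; `plot_trace_scott` is a hypothetical helper, not an ArviZ function, and the random draws only stand in for real samples.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import arviz as az

def plot_trace_scott(draws, labels):
    """Hypothetical helper: per-chain KDEs (Scott bandwidth) next to the traces.

    draws: array of shape (chain, draw, n_params)
    labels: one title per parameter
    """
    n_chain, _, n_param = draws.shape
    fig, axes = plt.subplots(n_param, 2, figsize=(10, 2 * n_param), squeeze=False)
    for j in range(n_param):
        ax_kde, ax_trace = axes[j]
        for c in range(n_chain):
            # Scott's rule: closed-form bandwidth, no ISJ root finding
            grid, pdf = az.kde(draws[c, :, j], bw="scott")
            ax_kde.plot(grid, pdf, lw=1, color=f"C{c}")
            ax_trace.plot(draws[c, :, j], lw=1, alpha=0.5, color=f"C{c}")
        ax_kde.set_title(labels[j])
        ax_trace.set_title(labels[j])
    return fig, axes

rng = np.random.default_rng(2)
fig, axes = plot_trace_scott(rng.normal(size=(4, 500, 6)), [f"theta{i}" for i in range(6)])
```

Each chain keeps the same colour in both columns, so a chain whose density looks odd can be matched to its trace.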

Thank you so much! This command az.rcParams["plot.density_kind"] = "hist" helped and resulted in the plot below:

It seems that the chains are not stuck, which means that there is something else going on in the KDE estimation.
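One possibility worth checking (an assumption, not something confirmed in this thread): with NumPy's default error settings, underflow in `np.exp` is silently flushed to zero, so `FloatingPointError: underflow encountered in exp` only appears if something in the session has switched underflow handling to "raise", for example a call like `np.seterr(all="raise")` in the sampler code. In `_fixed_point` from the traceback, the factor `exp(-k_sq * pi**2 * t)` underflows for large `k` as a matter of course. The sketch below reproduces that with plain NumPy; the grid length and `t` are illustrative values, not the ones ArviZ actually uses.

```python
import numpy as np

# Illustrative values: squared frequencies k^2 and a t inside brentq's bracket (0, 0.01)
k_sq = np.arange(1, 512, dtype=np.float64) ** 2
t = 0.005

def isj_exp_term(under_mode):
    """Evaluate the exp(-k^2 * pi^2 * t) factor under a given NumPy underflow mode."""
    with np.errstate(under=under_mode):
        return np.exp(-k_sq * np.pi**2 * t)

print("current settings:", np.geterr())

# NumPy's default ("ignore"): large-k terms quietly become 0.0
terms = isj_exp_term("ignore")
print("terms flushed to zero:", int(np.sum(terms == 0.0)))

# With underflow set to "raise", the same expression fails like the traceback
try:
    isj_exp_term("raise")
    raised = False
except FloatingPointError as err:
    raised = True
    print("raised:", err)
```

If `np.geterr()` reports `'under': 'raise'` in the session where plotting fails, wrapping the call as `with np.errstate(under="ignore"): az.plot_trace(infdat)` should let the default KDE run, and the same setting might also explain the az.plot_ppc failures.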