I’m also not sure how multiple samplers should be handled. I’d have to review the paper in depth, which I won’t be able to do anytime soon. I see basically two options, though, both of which can be made to work with the current ArviZ implementation using the following workarounds:
Generate one energy plot per sampler
If the way to go were to treat each NUTS as independent, you’d have to generate an energy subplot for each one, looping over this extra energy_dim_0 dimension. This could potentially be added as a feature to ArviZ, but for now you’ll have to handle it manually.
It will be something like:
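import arviz as az
import matplotlib.pyplot as plt

# one label per step method, assumed to follow the order of energy_dim_0
stepmethods = ("NUTS(sigma)", "NUTS(sigma, slope, intercept)")
_, axes = plt.subplots(1, len(stepmethods))
for i, sampler in enumerate(stepmethods):
    # select the energy of the i-th sampler only
    az.plot_energy(samples.sample_stats.isel(energy_dim_0=i), ax=axes[i])
    axes[i].set_title(sampler)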
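To see what the slicing in the loop hands to each subplot without running a sampler, here is a minimal sketch on synthetic data. The random values, the shape (2 chains, 100 draws, 2 samplers) and the helper name `bfmi_per_chain` are assumptions made for illustration; only the dimension name `energy_dim_0` comes from the snippet above. The helper computes the usual BFMI estimate (mean squared difference of consecutive energies divided by the energy variance), which is the kind of quantity an energy plot is normally used to diagnose.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)

# Synthetic stand-in for samples.sample_stats: one energy series per sampler.
# Shape and values are made up; a real trace would come from the sampler.
sample_stats = xr.Dataset(
    {"energy": (("chain", "draw", "energy_dim_0"), rng.normal(size=(2, 100, 2)))}
)


def bfmi_per_chain(energy):
    """Hypothetical helper: BFMI estimate per chain for a (chain, draw) array."""
    energy = np.asarray(energy)
    return np.square(np.diff(energy, axis=1)).mean(axis=1) / np.var(energy, axis=1)


per_sampler = {}
for i in range(sample_stats.sizes["energy_dim_0"]):
    # Same selection as in the plotting loop: this drops the sampler dimension,
    # leaving an energy variable with dims (chain, draw).
    stats_i = sample_stats.isel(energy_dim_0=i)
    per_sampler[i] = bfmi_per_chain(stats_i["energy"].values)
```

Each entry of `per_sampler` holds one value per chain, so the samplers can be compared numerically as well as visually.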
Combine the energy info of the multiple samplers
If the way to go were to combine the energy info of each step method into a single global quantity, it would be something like:
sample_stats = samples1.sample_stats
sample_stats["energy_steps"] = sample_stats["energy"]
del sample_stats["energy"]
# extra assumption: the energy info is combined by summing over the samplers
sample_stats["energy"] = sample_stats["energy_steps"].sum("energy_dim_0")
az.plot_energy(samples1)
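The same bookkeeping can be tried on synthetic data to confirm that the combined variable ends up with only the chain and draw dimensions. This is a sketch under stated assumptions: the random values and the shape are made up, `rename` is used here in place of the copy-and-delete steps, and summing over `energy_dim_0` is still the unverified assumption from above, not an established rule.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)

# Synthetic stand-in for sample_stats with one energy series per sampler.
sample_stats = xr.Dataset(
    {"energy": (("chain", "draw", "energy_dim_0"), rng.normal(size=(2, 50, 2)))}
)

# Keep the per-sampler energies under a new name.
sample_stats = sample_stats.rename({"energy": "energy_steps"})

# ASSUMPTION: the global energy is the sum of the per-sampler energies.
sample_stats["energy"] = sample_stats["energy_steps"].sum("energy_dim_0")
```

Note that `rename` returns a new Dataset rather than modifying the one stored in the InferenceData object, so with this variant the resulting `sample_stats` would be what gets passed on for plotting.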
If there is a clear, sampler-independent way to combine the energy info of the different step methods, it could also be added to the plot_energy code itself. Otherwise it would have to be a note or example in the docs, so that users combine that info themselves depending on the samplers and then pass the result to plot_energy.