So I’ve been dealing with a model that I recently discovered is highly degenerate and was generating lots of divergences and taking forever to run, so I decided to implement a custom step method which resolved the divergences part of the problem but throws an error when I attempt to call the plot_energy function from Arviz on the inference data. I have replicated this in a more minimal example:

```
{
"name": "ValueError",
"message": "('chain', 'draw') must be a permuted list of ('chain', 'draw', 'energy_dim_0'), unless `...` is included",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File c:\\Users\\ck21395\\PhD codes\\Neville_codes\\HMC Investigations\\Customsampling.py:305
302 print(az.summary(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"]))
304 if __name__=='__main__':
--> 305 main()
File c:\\Users\\ck21395\\PhD codes\\Neville_codes\\HMC Investigations\\Customsampling.py:300
297 print(samples1['posterior'])
299 az.plot_trace(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"],divergences=True)
--> 300 az.plot_energy(data=samples1)
301 az.plot_pair(samples1, var_names=[\"sigma\",\"slope\",\"intercept\"], divergences=True)
302 print(az.summary(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"]))
File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\arviz\\plots\\energyplot.py:109, in plot_energy(data, kind, bfmi, figsize, legend, fill_alpha, fill_color, bw, textsize, fill_kwargs, plot_kwargs, ax, backend, backend_kwargs, show)
9 def plot_energy(
10 data,
11 kind=None,
(...)
24 show=None,
25 ):
26 r\"\"\"Plot energy transition distribution and marginal energy distribution in HMC algorithms.
27
28 This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS.
(...)
107
108 \"\"\"
--> 109 energy = convert_to_dataset(data, group=\"sample_stats\").energy.transpose(\"chain\", \"draw\").values
111 if kind == \"histogram\":
112 warnings.warn(
113 \"kind histogram will be deprecated in a future release. Use `hist` \"
114 \"or set rcParam `plot.density_kind` to `hist`\",
115 FutureWarning,
116 )
File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\xarray\\core\\dataarray.py:3019, in DataArray.transpose(self, transpose_coords, missing_dims, *dims)
2986 \"\"\"Return a new DataArray object with transposed dimensions.
2987
2988 Parameters
(...)
3016 Dataset.transpose
3017 \"\"\"
3018 if dims:
-> 3019 dims = tuple(infix_dims(dims, self.dims, missing_dims))
3020 variable = self.variable.transpose(*dims)
3021 if transpose_coords:
File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\xarray\
amedarray\\utils.py:176, in infix_dims(dims_supplied, dims_all, missing_dims)
174 existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
175 if set(existing_dims) ^ set(dims_all):
--> 176 raise ValueError(
177 f\"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included\"
178 )
179 yield from existing_dims
ValueError: ('chain', 'draw') must be a permuted list of ('chain', 'draw', 'energy_dim_0'), unless `...` is included"
}
```

The relevant model block is:

```
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from pymc import HalfCauchy, Model, Normal, sample
print(f"Running on PyMC v{pm.__version__}")
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
def main():
size = 200
true_intercept = 1
true_slope = 2
x = np.linspace(0, 1, size)
# y = a + b*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + rng.normal(scale=0.5, size=size)
data = pd.DataFrame(dict(x=x, y=y))
with Model() as model1: # model specifications in PyMC are wrapped in a with-statement
# Define priors
sigma = HalfCauchy("sigma", beta=10)
intercept = Normal("intercept", 0, sigma=20)
slope = Normal("slope", 0, sigma=20)
# Define likelihood
likelihood = Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)
with model1:
stepmethod=[pm.NUTS([model1["sigma"]]),pm.NUTS([model1["sigma"],model1["slope"],model1["intercept"]])]
step1=pm.NUTS([model1["sigma"]])
step2=pm.NUTS([model1["sigma"],model1["slope"],model1["intercept"]])
stepmethod=[step1,step2]
samples1=sample(1000,step=stepmethod)
#samples1=sample(1000)
az.plot_trace(data=samples1,var_names=["sigma","slope","intercept"],divergences=True)
az.plot_energy(data=samples1)
az.plot_pair(samples1, var_names=["sigma","slope","intercept"], divergences=True)
print(az.summary(data=samples1,var_names=["sigma","slope","intercept"]))
if __name__=='__main__':
main()
```

Find attached the full code file if wanted. Thanks in advance.

Customsampling.py (2.3 KB)