ArviZ plot_energy doesn't like PyMC custom step method?

So I’ve been dealing with a model that I recently discovered is highly degenerate: it was generating lots of divergences and taking forever to run. I implemented a custom step method, which resolved the divergences, but it now throws an error when I call ArviZ's plot_energy function on the inference data. I have replicated this in a more minimal example:

{
	"name": "ValueError",
	"message": "('chain', 'draw') must be a permuted list of ('chain', 'draw', 'energy_dim_0'), unless `...` is included",
	"stack": "---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File c:\\Users\\ck21395\\PhD codes\\Neville_codes\\HMC Investigations\\Customsampling.py:305
    302     print(az.summary(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"]))
    304 if __name__=='__main__':
--> 305     main()

File c:\\Users\\ck21395\\PhD codes\\Neville_codes\\HMC Investigations\\Customsampling.py:300
    297 print(samples1['posterior'])
    299 az.plot_trace(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"],divergences=True)
--> 300 az.plot_energy(data=samples1)
    301 az.plot_pair(samples1, var_names=[\"sigma\",\"slope\",\"intercept\"], divergences=True)
    302 print(az.summary(data=samples1,var_names=[\"sigma\",\"slope\",\"intercept\"]))

File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\arviz\\plots\\energyplot.py:109, in plot_energy(data, kind, bfmi, figsize, legend, fill_alpha, fill_color, bw, textsize, fill_kwargs, plot_kwargs, ax, backend, backend_kwargs, show)
      9 def plot_energy(
     10     data,
     11     kind=None,
   (...)
     24     show=None,
     25 ):
     26     r\"\"\"Plot energy transition distribution and marginal energy distribution in HMC algorithms.
     27 
     28     This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS.
   (...)
    107 
    108     \"\"\"
--> 109     energy = convert_to_dataset(data, group=\"sample_stats\").energy.transpose(\"chain\", \"draw\").values
    111     if kind == \"histogram\":
    112         warnings.warn(
    113             \"kind histogram will be deprecated in a future release. Use `hist` \"
    114             \"or set rcParam `plot.density_kind` to `hist`\",
    115             FutureWarning,
    116         )

File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\xarray\\core\\dataarray.py:3019, in DataArray.transpose(self, transpose_coords, missing_dims, *dims)
   2986 \"\"\"Return a new DataArray object with transposed dimensions.
   2987 
   2988 Parameters
   (...)
   3016 Dataset.transpose
   3017 \"\"\"
   3018 if dims:
-> 3019     dims = tuple(infix_dims(dims, self.dims, missing_dims))
   3020 variable = self.variable.transpose(*dims)
   3021 if transpose_coords:

File c:\\Users\\ck21395\\Anaconda3\\envs\\pymc_env2\\Lib\\site-packages\\xarray\\namedarray\\utils.py:176, in infix_dims(dims_supplied, dims_all, missing_dims)
    174 existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
    175 if set(existing_dims) ^ set(dims_all):
--> 176     raise ValueError(
    177         f\"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included\"
    178     )
    179 yield from existing_dims

ValueError: ('chain', 'draw') must be a permuted list of ('chain', 'draw', 'energy_dim_0'), unless `...` is included"
}

The relevant model block is:

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr

from pymc import HalfCauchy, Model, Normal, sample

print(f"Running on PyMC v{pm.__version__}")
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)


az.style.use("arviz-darkgrid")

def main():
    size = 200
    true_intercept = 1
    true_slope = 2

    x = np.linspace(0, 1, size)
    # y = a + b*x
    true_regression_line = true_intercept + true_slope * x
    # add noise
    y = true_regression_line + rng.normal(scale=0.5, size=size)

    data = pd.DataFrame(dict(x=x, y=y))
    
    with Model() as model1:  # model specifications in PyMC are wrapped in a with-statement
        # Define priors
        sigma = HalfCauchy("sigma", beta=10)
        intercept = Normal("intercept", 0, sigma=20)
        slope = Normal("slope", 0, sigma=20)

        # Define likelihood
        likelihood = Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)
    
    with model1:
        # Compound step: sample sigma on its own first, then all three together
        step1 = pm.NUTS([model1["sigma"]])
        step2 = pm.NUTS([model1["sigma"], model1["slope"], model1["intercept"]])
        stepmethod = [step1, step2]
        samples1 = sample(1000, step=stepmethod)
        #samples1=sample(1000)
   
    az.plot_trace(data=samples1, var_names=["sigma", "slope", "intercept"], divergences=True)
    az.plot_energy(data=samples1)
    az.plot_pair(samples1, var_names=["sigma", "slope", "intercept"], divergences=True)
    print(az.summary(data=samples1, var_names=["sigma", "slope", "intercept"]))
    
if __name__=='__main__':
    main()

The full code file is attached in case it's wanted. Thanks in advance.
Customsampling.py (2.3 KB)

Which version are you on? I am on PyMC 5.10.0, and the only error I get is step-related; I assume this is because you set the step for sigma twice. Once I change it to

stepmethod = pm.NUTS([model1["sigma"], model1["slope"], model1["intercept"]])
samples1 = sample(1000, step=stepmethod)

or let PyMC choose the steps, it runs fine.



The slope and intercept are correlated, probably because the data is not standardised.

Thanks for replying!

I am on v5.11.0 of PyMC, so it could either be a bug or the Python gods screwing with me again :sweat_smile:

I have ruled out (at least on my machine) that it's the duplicate sampling of the same variable causing it, since the following code still throws the error:
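(Roughly the following, a sketch of the non-overlapping variant rather than the exact file, with sigma now appearing in only one step:)

with model1:
    # hypothetical reconstruction: the two NUTS steps no longer share any variable
    step1 = pm.NUTS([model1["sigma"]])
    step2 = pm.NUTS([model1["slope"], model1["intercept"]])
    samples1 = sample(1000, step=[step1, step2])

az.plot_energy(data=samples1)  # still raises the same ValueError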

Running the default step method, or the single step you mentioned, causes no issues, which leads me to believe that having multiple steps in a list like stepmethod=[step1, step2] is what triggers the problem. But if it isn't causing issues on v5.10.0, then perhaps it is a bug of sorts: I can get the trace plot out fine, but the ArviZ diagnostic plots don't like it, and the plot_pair function throws its own error:

The conda environment I'm working in should be clean, so I will try rolling back to v5.10.0; that should show whether it is a v5.11.0 bug or a problem with my own conda environment.
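In the meantime, the error message itself points at an extra dimension on the energy statistic: with the compound step list, each NUTS step appears to record its own energy. A quick check (just a sketch, run after the sampling above) shows the shape that plot_energy trips over:

print(samples1.sample_stats["energy"].dims)
# ('chain', 'draw', 'energy_dim_0') with the compound step list,
# versus ('chain', 'draw') with a single NUTS step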

Slopes and intercepts have to be correlated, because of the geometry of a line, regardless of centering or scaling.

Not when the covariates are standardised, though: with the above model, try x = (x - x.mean()) / x.std().

At least in the likelihood-only approach, you can show that the correlation matrix between the slope and intercept has off-diagonal terms proportional to -\frac{1}{n}\sum_i x_i, which is 0 when the covariates are centered. From a geometric perspective, if your line passes through the origin, a change in slope does not produce a change in intercept, unlike for a line that does not pass through the origin.
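For what it's worth, here is a quick numerical check of that claim (just a sketch using plain least squares, where the covariance of the estimates is proportional to the inverse of X^T X, rather than the PyMC posterior):

import numpy as np

def slope_intercept_corr(x):
    # design matrix for y = intercept + slope * x
    X = np.column_stack([np.ones_like(x), x])
    # covariance of (intercept, slope) is proportional to (X^T X)^{-1}
    cov = np.linalg.inv(X.T @ X)
    return cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])

x = np.linspace(0, 1, 200)
print(slope_intercept_corr(x))                         # strongly negative
print(slope_intercept_corr((x - x.mean()) / x.std()))  # essentially 0 once standardised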


This also happens to me on 5.10. I suspect you should group all NUTS steps into one group rather than two separate groups. I don't know if this is a design choice or a bug; I will let a dev comment on that. But what you want can be achieved via

step = pm.NUTS([model1["sigma"], model1["slope"], model1["intercept"]])

Or are you trying to achieve something different than this?


Yeah, I was doing it hierarchically in that fashion because my original model (not the working example I provided) was quite degenerate/non-identifiable, meaning there were many different optimal configurations, which was screwing up the HMC sampling. So I tried to reduce the configuration space, in a manner of speaking, by first sampling one free variable, then three free variables, and then finally all free variables. At least that is my understanding of that specific problem, and it is why I tried this approach, which appeared to work apart from this hitch of breaking the ArviZ diagnostic plots.

I got the same error that I think you are mentioning, and it reminded me why I rolled forward to 5.11.0 in the first place, namely the truth value error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

which, following some investigation, was discovered and resolved in:

So, in summary: custom step methods expressed as [step1, step2, ...] throw this error, and it may be either a misapplication of the argument on my part or a bug.


OK, got it now. I did not know compounding steps in different ways could be used like that; thanks for the explanation. When I compound the steps the way you did, I also get the same ValueError. This looks like a dev-level kind of question, so I will let someone else answer.


The ArviZ energy function does not seem to handle traces from multiple NUTS steps. I am not sure the statistics it computes make sense with multiple NUTS samplers; perhaps someone like @OriolAbril or @colcarroll will be able to chime in.


I'm also not sure how multiple samplers should be handled. I'd have to review the paper in depth, which I won't be able to do anytime soon. I see basically two options, though, that could be made to work with the current ArviZ implementation via the following workarounds:

Generate one energy plot per sampler

If the way to go were to treat each NUTS as independent, you'd then have to generate an energy subplot for each, looping over this extra energy_dim_0. This could potentially be added as a feature to ArviZ, but for now you'll have to handle it manually.
It would be something like:

stepmethods = ("NUTS(sigma)", "NUTS(sigma, slope, intercept")
_, axes = plt.subplots(1, len(stepmethods))
for i, sampler in enumerate(stepmethods):
    az.plot_energy(samples.sample_stats.isel(energy_dim_0=i), ax=axes[i])
    axes[i].set_title(sampler)

Combine the energy info of the multiple samplers

If the way to go were to combine the energy info of each step method into a global quantity it would be something like:

sample_stats = samples1.sample_stats
# keep the per-step energies around under a different name
sample_stats["energy_steps"] = sample_stats["energy"]
del sample_stats["energy"]
# extra assumption: the right way to combine the energy info is a sum
sample_stats["energy"] = sample_stats["energy_steps"].sum("energy_dim_0")
az.plot_energy(samples1)

If the way to combine the energy info of the different step methods is clear and independent of the sampler, it could also be added to the plot_energy code itself. Otherwise it'd have to be a note/example in the docs, so users combine that info themselves depending on the samplers and then pass it to plot_energy.
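And once the combined energy is back to just ("chain", "draw"), the per-chain BFMI values that plot_energy adds to the plot can also be checked directly (just a usage sketch):

print(az.bfmi(samples1))  # estimated BFMI per chain, computed from the combined energy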


Thanks @OriolAbril, I will probably lean more towards the former option of diagnosing each NUTS sampler independently, but I may try out both snippets and have a bit of a read/play myself to see how the energy information of multiple samplers gets combined, as you point out.
