pm.sample() result does not work with az.plot_trace or other az functions

I got idata from pm.sample() and tried to draw a trace plot with ArviZ (az). However, it raises the following error:

Can only convert xarray dataarray, xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not tuple

When I check the data, it shows:
(Inference data with groups:
> posterior
> log_likelihood
> sample_stats
> observed_data,)

In addition, when I call pm.plot_posterior(idata), I get the same error:

ValueError: Can only convert xarray dataarray, xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not tuple

I am using pymc version 4.0.0.

Please advise!

You seem to have a basic error: you are passing a tuple instead of an InferenceData. Can you share the code and/or a longer error message?
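A quick way to confirm this diagnosis is to inspect the type of the object before plotting. The sketch below uses a hypothetical stand-in class, FakeInferenceData, and a hypothetical fake_sample() function in place of arviz.InferenceData and pm.sample(), so it runs without PyMC installed; with the real objects, isinstance(idata, az.InferenceData) is the equivalent check.

```python
class FakeInferenceData:
    """Hypothetical stand-in for arviz.InferenceData (illustration only)."""


def fake_sample():
    """Hypothetical stand-in for pm.sample(): returns a single results object."""
    return FakeInferenceData()


idata = fake_sample(),  # note the trailing comma

# Inspect what was actually assigned before handing it to a plotting function
print(type(idata).__name__)  # tuple
print(len(idata))            # 1

# Stopgap only: unpack a length-1 tuple; the real fix is removing the comma
if isinstance(idata, tuple) and len(idata) == 1:
    (idata,) = idata

print(type(idata).__name__)  # FakeInferenceData
```

The unpacking step just hides the symptom; finding and deleting the stray comma is the actual fix.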

print(pm.__version__)
print(az.__version__)
4.0.0
0.12.1

Here is the code and the longer error message:

Code:

import numpy as np
import pymc as pm

def fn(a, b, c, x):
    [...]
    return y

# x and y_obs are the poster's data (not shown)
with pm.Model() as model:
    
    a = pm.TruncatedNormal('a', mu=1/50, sigma=1, lower=1/1000, upper=1, initval=1/50)
    b = pm.Normal('b', 0, 10, initval=0)
    c = pm.Normal('c', 0, 10, initval=0)
    ϵ = pm.HalfCauchy('ϵ', 100, initval=1)
    mu = fn(a, b, c, x)
    pm.Potential("negative_penalty", pm.math.switch(mu < 0, -np.inf, 0))
    
    y_pred = pm.Gamma('y_pred', mu=mu, sigma=ϵ, observed=y_obs)

    idata = pm.sample(draws=1000, tune=500, chains=2, target_accept=0.999)

Check the idata variable:

idata
(Inference data with groups:
 	> posterior
 	> log_likelihood
 	> sample_stats
 	> observed_data,)

Error message:

az.plot_trace(idata, var_names=["a", "b", "c"])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [107], in <cell line: 1>()
----> 1 az.plot_trace(idata, var_names=["a", "b"])

File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/plots/traceplot.py:194, in plot_trace(data, var_names, filter_vars, transform, coords, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, labeller, axes, backend, backend_config, backend_kwargs, show)
    191 else:
    192     divergence_data = False
--> 194 coords_data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    196 if transform is not None:
    197     coords_data = transform(coords_data)

File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/data/converters.py:179, in convert_to_dataset(obj, group, coords, dims)
    140 def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
    141     """Convert a supported object to an xarray dataset.
    142 
    143     This function is idempotent, in that it will return xarray.Dataset functions
   (...)
    177     xarray.Dataset
    178     """
--> 179     inference_data = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
    180     dataset = getattr(inference_data, group, None)
    181     if dataset is None:

File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/data/converters.py:131, in convert_to_inference_data(obj, group, coords, dims, **kwargs)
    116 else:
    117     allowable_types = (
    118         "xarray dataarray",
    119         "xarray dataset",
   (...)
    129         "cmdstanpy fit",
    130     )
--> 131     raise ValueError(
    132         "Can only convert {} to InferenceData, not {}".format(
    133             ", ".join(allowable_types), obj.__class__.__name__
    134         )
    135     )
    137 return InferenceData(**{group: dataset})

ValueError: Can only convert xarray dataarray, xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not tuple
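The last frame of the traceback shows where the message comes from: the converter checks the object against the types it knows and, if none match, formats obj.__class__.__name__ into the error. Below is a simplified sketch of that dispatch logic. It is not ArviZ's actual implementation; the accepted types are reduced to a hypothetical FakeInferenceData stand-in and dict. It shows why the message ends with "not tuple".

```python
class FakeInferenceData:
    """Hypothetical stand-in for arviz.InferenceData (illustration only)."""


def convert(obj):
    """Simplified sketch of a type-dispatching converter (not ArviZ's code)."""
    if isinstance(obj, FakeInferenceData):
        return obj  # already the right type: returned unchanged
    if isinstance(obj, dict):
        return FakeInferenceData()  # pretend the dict was converted
    allowable_types = ("inference data", "dict")
    raise ValueError(
        "Can only convert {} to InferenceData, not {}".format(
            ", ".join(allowable_types), obj.__class__.__name__
        )
    )


wrapped = (FakeInferenceData(),)  # what a trailing comma produces
message = ""
try:
    convert(wrapped)
except ValueError as err:
    message = str(err)

print(message)  # Can only convert inference data, dict to InferenceData, not tuple
```

A tuple is not in the list of known types, so it falls through to the error branch even though the object inside it would have been accepted.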

Not sure how that happened, as your code shouldn’t have produced such a result. Somehow your idata object is wrapped in a tuple (notice the parentheses around the result). When I run a simplified version of your model, this is what I get:
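The parentheses and the trailing comma in the printed output are simply how Python displays a length-1 tuple: the tuple's repr wraps its element's repr in "(" and ",)". A minimal sketch with a hypothetical stand-in class whose repr imitates the InferenceData summary:

```python
class FakeInferenceData:
    """Hypothetical stand-in whose repr imitates the InferenceData summary."""

    def __repr__(self):
        return "Inference data with groups:\n\t> posterior\n\t> observed_data"


single = FakeInferenceData()
wrapped = (single,)

print(repr(single))
print(repr(wrapped))  # same text, but starting with "(" and ending with ",)"
```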

In [3]: idata
Out[3]: 
Inference data with groups:
	> posterior
	> log_likelihood
	> sample_stats
	> observed_data

Yes, this is very weird. Why is my idata a tuple rather than an arviz.InferenceData? When I do

pm.plot_forest(idata)

it works, but none of the other pm.plot_* functions do.

I also tried with PyMC3, but the same issue persists.
Any other suggestions?

Thank you.

I would try to modify your model until you figure out what’s going on. Here’s the simplified version that I ran:

import pymc as pm

with pm.Model() as model:
    
    a = pm.TruncatedNormal('a', mu=1/50, sigma=1, lower=1/1000, upper=1, initval=1/50)
    b = pm.Normal('b', 0, 10, initval=0)
    c = pm.Normal('c', 0, 10, initval=0)
    ϵ  = pm.HalfCauchy('ϵ', 100, initval=1)
    #mu = fn(a, b, c, x)
    #pm.Potential("negative_penalty", pm.math.switch(mu<0, -np.inf, 0))
    mu = 1
    
    y_pred = pm.Gamma('y_pred', mu=mu, sigma=ϵ, observed=[1,2,3])

    idata = pm.sample()

Make sure there is no comma after the pm.sample call because that would create a tuple:

In [1]: import arviz as az
In [2]: idata = az.load_arviz_data("rugby"),

In [3]: idata
Out[3]: 
(Inference data with groups:
 	> posterior
 	> posterior_predictive
 	> sample_stats
 	> prior
 	> observed_data,)

Note that the InferenceData text is inside parentheses. The comma after the load_arviz_data call puts the InferenceData returned by the function inside a tuple. On the other hand:

In [4]: idata = az.load_arviz_data("rugby")

In [5]: idata
Out[5]: 
Inference data with groups:
	> posterior
	> posterior_predictive
	> sample_stats
	> prior
	> observed_data

Returns the InferenceData without wrapping it inside a tuple. I know I am not using pm.sample, but I am using a function that returns an InferenceData object, so I am quite sure the two are interchangeable for this specific issue.


@OriolAbril You pinned it down! I accidentally put a comma (,) after pm.sample().
I should have put the comma after the # when I commented out the other arguments.
Thank you for your time and help! (@ricardoV94 and @cluhmann as well!)
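Based on the description above, one plausible way this happens (an assumption, since the exact original line was not posted) is commenting out the last argument of the pm.sample() call and leaving the comma on the wrong side of the #. Everything after # is ignored, so a comma before it turns the assignment into a 1-tuple. A minimal sketch with a hypothetical fake_sample() standing in for pm.sample():

```python
def fake_sample(**kwargs):
    """Hypothetical stand-in for pm.sample(): just returns its keyword arguments."""
    return dict(kwargs)


# Comma left before the '#': Python sees `fake_sample(...),`, i.e. a 1-tuple
bad = fake_sample(draws=1000, tune=500, chains=2), #target_accept=0.999)

# Comma moved after the '#': the commented-out argument is dropped cleanly
good = fake_sample(draws=1000, tune=500, chains=2) #, target_accept=0.999)

print(type(bad).__name__)   # tuple
print(type(good).__name__)  # dict
```

Both lines are valid Python, which is why no SyntaxError warns you; the mistake only surfaces later, when a function that expects a single object receives the tuple.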


Side note: plot_forest and plot_density work in both cases because they accept either an InferenceData object or an iterable of InferenceData objects, and a length-1 tuple containing an InferenceData is an iterable of InferenceData objects.
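To see why a length-1 tuple slips through plot_forest but not plot_trace, here is a simplified sketch of the "one object or an iterable of objects" pattern. It is not ArviZ's actual code; FakeInferenceData, as_list, forest_like and trace_like are hypothetical names for illustration.

```python
class FakeInferenceData:
    """Hypothetical stand-in for arviz.InferenceData (illustration only)."""


def as_list(data):
    """Accept one object or an iterable of objects and always return a list."""
    if isinstance(data, FakeInferenceData):
        return [data]  # a single object gets wrapped
    return list(data)  # any iterable (tuple, list, ...) gets unpacked


def forest_like(data):
    """Multi-model style function: works on a single object or a tuple of them."""
    return [type(d).__name__ for d in as_list(data)]


def trace_like(data):
    """Single-model style function: requires exactly one FakeInferenceData."""
    if not isinstance(data, FakeInferenceData):
        raise ValueError("not {}".format(type(data).__name__))
    return type(data).__name__


idata = FakeInferenceData()
print(forest_like(idata))     # ['FakeInferenceData']
print(forest_like((idata,)))  # ['FakeInferenceData']
print(trace_like(idata))      # FakeInferenceData
```

Calling trace_like((idata,)) raises ValueError("not tuple"), mirroring the behaviour reported above: the multi-model function silently unpacks the tuple, while the single-model one rejects it.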


Yes! I was about to ask about this. Your explanation clears everything up now.