# Report the installed PyMC and ArviZ versions, one per line.
print(pm.__version__, az.__version__, sep="\n")
4.0.0
0.12.1
Here is the code, followed by the full error message:
Code:
def fn(a,b,c,x):
    """Compute the mean passed as ``mu`` to the ``pm.Gamma`` likelihood.

    Called with the random variables ``a``, ``b``, ``c`` and the data ``x``.
    The body was elided ("[...]") in the original post.

    NOTE(review): presumably builds ``y`` from a, b, c and x with PyMC/Aesara
    tensor ops so it stays differentiable for NUTS -- TODO confirm. The model
    adds a ``pm.Potential`` penalty for negative results, which suggests this
    can return negative values.
    """
    [...]
    return y
with pm.Model() as model:
    # Priors; ``initval`` fixes the starting point of every chain.
    a = pm.TruncatedNormal('a', mu=1/50, sigma=1, lower=1/1000, upper=1, initval=1/50)
    b = pm.Normal('b', 0, 10, initval=0)
    c = pm.Normal('c', 0, 10, initval=0)
    ϵ = pm.HalfCauchy('ϵ', 100, initval=1)

    mu = fn(a, b, c, x)
    # Add -inf to the model log-probability whenever the mean goes negative,
    # so the sampler rejects those points (a Gamma mean must be positive).
    pm.Potential("negative_penalty", pm.math.switch(mu < 0, -np.inf, 0))

    y_pred = pm.Gamma('y_pred', mu=mu, sigma=ϵ, observed=y_obs)

    # Bind the InferenceData object itself. A stray trailing comma (or the
    # extra ")" that was here) wraps it in a 1-tuple, which az.plot_trace then
    # rejects with "Can only convert ... to InferenceData, not tuple".
    idata = pm.sample(draws=1000, tune=500, chains=2, target_accept=0.999)
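Inspecting the `idata` variable shows that it is a one-element tuple wrapping the `InferenceData` object (note the enclosing parentheses and the trailing comma):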
# Display idata. The output below is "(Inference data ...,)": the parentheses
# and trailing comma mean idata is a tuple, not an InferenceData object.
idata
(Inference data with groups:
> posterior
> log_likelihood
> sample_stats
> observed_data,)
Error message:
# Trace plots for the three parameters; ``data`` must be an InferenceData
# (or another type ArviZ can convert), not a tuple wrapping one.
az.plot_trace(data=idata, var_names=["a", "b", "c"])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [107], in <cell line: 1>()
----> 1 az.plot_trace(idata, var_names=["a", "b"])
File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/plots/traceplot.py:194, in plot_trace(data, var_names, filter_vars, transform, coords, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, labeller, axes, backend, backend_config, backend_kwargs, show)
191 else:
192 divergence_data = False
--> 194 coords_data = get_coords(convert_to_dataset(data, group="posterior"), coords)
196 if transform is not None:
197 coords_data = transform(coords_data)
File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/data/converters.py:179, in convert_to_dataset(obj, group, coords, dims)
140 def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
141 """Convert a supported object to an xarray dataset.
142
143 This function is idempotent, in that it will return xarray.Dataset functions
(...)
177 xarray.Dataset
178 """
--> 179 inference_data = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
180 dataset = getattr(inference_data, group, None)
181 if dataset is None:
File /opt/anaconda3/envs/pymc/lib/python3.10/site-packages/arviz/data/converters.py:131, in convert_to_inference_data(obj, group, coords, dims, **kwargs)
116 else:
117 allowable_types = (
118 "xarray dataarray",
119 "xarray dataset",
(...)
129 "cmdstanpy fit",
130 )
--> 131 raise ValueError(
132 "Can only convert {} to InferenceData, not {}".format(
133 ", ".join(allowable_types), obj.__class__.__name__
134 )
135 )
137 return InferenceData(**{group: dataset})
ValueError: Can only convert xarray dataarray, xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not tuple