Hi,
I am new to PyMC and like the API a lot: I find it more intuitive than, for example, NumPyro. At the moment, though, I am hitting a snag when using `az.compare` with PyMC 4.2.0.
I created a minimal example below. I tried installing PyMC 4.2.0 with both pip and Anaconda, and I also tried adding the log likelihood to the InferenceData separately.
How can I make this work?
Cheers,
Sanne
def setup_model(x, y, intercept=True):
    """Build a linear regression model of ``y`` on ``x``.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values, registered on the model as
        ``pm.ConstantData`` named ``"x"`` and ``"y"``.
    intercept : bool, default True
        When true, the mean includes an intercept ``α ~ Normal(0, 1)``;
        when false, the intercept is the constant 0.

    Returns
    -------
    pm.Model
        A model with slope ``β ~ Normal(0, 1)``, noise scale
        ``σ ~ HalfNormal(1)``, deterministic mean ``μ = α + β * x`` and a
        Normal likelihood ``obs`` observed at ``y``.
    """
    with pm.Model() as model:
        # Distinct local names so the raw arrays passed in are not shadowed.
        x_data = pm.ConstantData('x', x)
        y_data = pm.ConstantData('y', y)

        α = pm.Normal("α", 0, 1) if intercept else 0
        β = pm.Normal("β", 0, 1)
        σ = pm.HalfNormal("σ", 1)

        μ = pm.Deterministic('μ', α + β * x_data)
        pm.Normal("obs", μ, σ, observed=y_data)
    return model
# Simulate data from y = 2x + noise, with x ~ Uniform(-5, 5).
# A fixed seed makes the simulated data (and the comparison table) reproducible.
rng = np.random.default_rng(42)
x = 10 * rng.uniform(size=100) - 5
e = rng.normal(0, 1, size=100)
y = 2 * x + e

# az.compare estimates LOO from the pointwise log-likelihood, so request it
# explicitly instead of relying on the PyMC version's default (stored by
# default in PyMC 4.x, but not in later releases).
trace_1 = pm.sample(
    model=setup_model(x, y, True),
    random_seed=1,
    idata_kwargs={"log_likelihood": True},
)
trace_2 = pm.sample(
    model=setup_model(x, y, False),
    random_seed=2,
    idata_kwargs={"log_likelihood": True},
)

# NOTE(review): the reported InvalidIndexError is raised inside ArviZ's
# compare (`df_comp.at[val] = (...)` handed to pandas' scalar `.at` setter),
# which points to an ArviZ/pandas version mismatch rather than a problem with
# these models. Likely fix: upgrade ArviZ (or pin an older pandas) — confirm
# against the installed versions.
df_comp_loo = az.compare({"1": trace_1, "2": trace_2})
df_comp_loo
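The error is below. (This traceback is from the variant that adds the log likelihood separately with `pm.to_inference_data`.)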
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/indexes/base.py:3800, in Index.get_loc(self, key, method, tolerance)
3799 try:
-> 3800 return self._engine.get_loc(casted_key)
3801 except KeyError as err:
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/_libs/index.pyx:138, in pandas._libs.index.IndexEngine.get_loc()
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/_libs/index.pyx:144, in pandas._libs.index.IndexEngine.get_loc()
TypeError: 'slice(None, None, None)' is an invalid key
During handling of the above exception, another exception occurred:
InvalidIndexError Traceback (most recent call last)
Cell In [7], line 32
29 trace2 = pm.sample(1000, return_inferencedata=False)
30 idata2 = pm.to_inference_data(trace=trace2, log_likelihood=True)
---> 32 df_comp_loo = az.compare({"1": idata2, "2": idata2})
33 df_comp_loo
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/arviz/stats/stats.py:306, in compare(compare_dict, ic, method, b_samples, alpha, seed, scale, var_name)
304 std_err = ses.loc[val]
305 weight = weights[idx]
--> 306 df_comp.at[val] = (
307 idx,
308 res[ic],
309 res[p_ic],
310 d_ic,
311 weight,
312 std_err,
313 d_std_err,
314 res["warning"],
315 res[scale_col],
316 )
318 df_comp["rank"] = df_comp["rank"].astype(int)
319 df_comp["warning"] = df_comp["warning"].astype(bool)
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/indexing.py:2438, in _AtIndexer.__setitem__(self, key, value)
2435 self.obj.loc[key] = value
2436 return
-> 2438 return super().__setitem__(key, value)
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/indexing.py:2393, in _ScalarAccessIndexer.__setitem__(self, key, value)
2390 if len(key) != self.ndim:
2391 raise ValueError("Not enough indexers for scalar access (setting)!")
-> 2393 self.obj._set_value(*key, value=value, takeable=self._takeable)
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/frame.py:4208, in DataFrame._set_value(self, index, col, value, takeable)
4206 iindex = cast(int, index)
4207 else:
-> 4208 icol = self.columns.get_loc(col)
4209 iindex = self.index.get_loc(index)
4210 self._mgr.column_setitem(icol, iindex, value)
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/indexes/base.py:3807, in Index.get_loc(self, key, method, tolerance)
3802 raise KeyError(key) from err
3803 except TypeError:
3804 # If we have a listlike key, _check_indexing_error will raise
3805 # InvalidIndexError. Otherwise we fall through and re-raise
3806 # the TypeError.
-> 3807 self._check_indexing_error(key)
3808 raise
3810 # GH#42269
File ~/miniforge3/envs/pymc-420/lib/python3.10/site-packages/pandas/core/indexes/base.py:5963, in Index._check_indexing_error(self, key)
5959 def _check_indexing_error(self, key):
5960 if not is_scalar(key):
5961 # if key is not a scalar, directly raise an error (the code below
5962 # would convert to numpy arrays and raise later any way) - GH29926
-> 5963 raise InvalidIndexError(key)
InvalidIndexError: slice(None, None, None)