Hi PyMC3 community,
I ran into the following issue with PyMC3 3.8: `trace_to_dataframe` raises a `KeyError` if it is called on a trace from which a random variable (RV) was removed via `remove_values`. Here is a MWE:
```python
import pymc3 as pm
import numpy as np

data = np.random.normal(loc=2, scale=0.5, size=10)

with pm.Model() as m:
    mu = pm.Normal("mu", mu=3, sigma=1)
    sig = pm.InverseGamma("sigma", alpha=1, beta=1)
    pm.Normal("d", mu=mu, sigma=sig, observed=data)

trace = pm.sample(model=m, chains=2, cores=1)
trace.remove_values("sigma")  # drop the "sigma" samples from the trace
pm.summary(trace)             # this call still works
df = pm.trace_to_dataframe(trace)  # raises KeyError: 'sigma'
df.head()
```
The error is:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Untitled-1 in <module>
     14 trace.remove_values("sigma")
     15 pm.summary(trace)
---> 16 df = pm.trace_to_dataframe(trace)
     17 df.head()

~\miniconda3\envs\dive\lib\site-packages\pymc3\backends\tracetab.py in trace_to_dataframe(trace, chains, varnames, include_transformed)
     36     var_dfs = []
     37     for v in varnames:
---> 38         vals = trace.get_values(v, combine=True, chains=chains)
     39         flat_vals = vals.reshape(vals.shape[0], -1)
     40         var_dfs.append(pd.DataFrame(flat_vals, columns=flat_names[v]))

~\miniconda3\envs\dive\lib\site-packages\pymc3\backends\base.py in get_values(self, varname, burn, thin, combine, chains, squeeze)
    471             try:
    472                 results = [self._straces[chain].get_values(varname, burn, thin)
--> 473                            for chain in chains]
    474             except TypeError:  # Single chain passed.
    475                 results = [self._straces[chains].get_values(varname, burn, thin)]

~\miniconda3\envs\dive\lib\site-packages\pymc3\backends\base.py in <listcomp>(.0)
    471             try:
    472                 results = [self._straces[chain].get_values(varname, burn, thin)
--> 473                            for chain in chains]
    474             except TypeError:  # Single chain passed.
    475                 results = [self._straces[chains].get_values(varname, burn, thin)]

~\miniconda3\envs\dive\lib\site-packages\pymc3\backends\ndarray.py in get_values(self, varname, burn, thin)
    286         A NumPy array
    287         """
--> 288         return self.samples[varname][burn::thin]
    289 
    290     def _slice(self, idx):

KeyError: 'sigma'
```
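For reference, the loop in `tracetab.py` in the traceback builds one DataFrame column per scalar element of each variable: it calls `vals.reshape(vals.shape[0], -1)` and then `pd.DataFrame(flat_vals, columns=flat_names[v])`. A minimal sketch of the same flattening outside PyMC3 follows. The helper `samples_to_dataframe` is hypothetical and not part of PyMC3. It assumes the draws are already in a plain dict that maps each variable name to an array of shape `(n_draws, *var_shape)`, for example one collected with `trace.get_values(v, combine=True)` for each remaining name in `trace.varnames`. The `name__i_j` column naming is also an assumption that is meant to mimic PyMC3's flat names. The helper only iterates over the names that are actually present in the dict, so a removed variable is skipped instead of being looked up.

```python
import numpy as np
import pandas as pd


def samples_to_dataframe(samples):
    """Flatten a dict of per-variable draws into one wide DataFrame.

    Hypothetical helper (not part of PyMC3). `samples` maps a variable
    name to an array of shape (n_draws, *var_shape), with draws assumed
    to be already combined across chains.
    """
    frames = []
    for name, vals in samples.items():
        vals = np.asarray(vals)
        # Same flattening step as in the traceback: one row per draw,
        # one column per scalar element of the variable (C order).
        flat = vals.reshape(vals.shape[0], -1)
        if vals.ndim == 1:
            columns = [name]
        else:
            # Assumed naming scheme "name__i_j", mimicking PyMC3 flat names;
            # np.ndindex walks indices in the same C order as reshape.
            columns = [
                name + "__" + "_".join(map(str, idx))
                for idx in np.ndindex(*vals.shape[1:])
            ]
        frames.append(pd.DataFrame(flat, columns=columns))
    return pd.concat(frames, axis=1)


# Only the variables still present are iterated, so a removed
# variable (here "sigma") is simply absent instead of raising KeyError.
rng = np.random.default_rng(0)
draws = {
    "mu": rng.normal(2.0, 0.1, size=1000),
    "theta": rng.normal(size=(1000, 2, 3)),
}
df = samples_to_dataframe(draws)
print(df.shape)              # (1000, 7)
print(list(df.columns)[:3])  # ['mu', 'theta__0_0', 'theta__0_1']
```

Iterating over the dict's own keys, instead of over a separately stored list of names, keeps the column set consistent with whatever samples actually remain.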
Is there a workaround for this? I also see that `trace_to_dataframe` is going to be deprecated and removed (issue #3907). What should I use instead to convert a trace to a pandas DataFrame?
Thanks!