Hello,
I am trying to sample from my posterior predictive distribution, but `sample_posterior_predictive` takes a very long time (a few minutes) before the progress bar even shows, and it is slow even when predicting a single sample. After some code reading, I believe the `dataset_to_point_list` function in `util.py` is causing the performance issue.
In `sampling.py`, `sample_posterior_predictive()`:

```python
elif isinstance(trace, xarray.Dataset):
    idata_kwargs["coords"].setdefault("draw", trace["draw"])
    idata_kwargs["coords"].setdefault("chain", trace["chain"])
    _trace = dataset_to_point_list(trace)
    nchain, len_trace = chains_and_samples(trace)
```
And `dataset_to_point_list` builds its result with nested Python loops over chains and draws (plus a dict comprehension over variables), calling label-based `.sel()` once per draw, which can be quite slow.

In `util.py`, `dataset_to_point_list()`:

```python
for c in ds.chain:
    for d in ds.draw:
        points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
```
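For illustration, a rough sketch of the kind of change I have in mind (the function name `dataset_to_point_list_fast` is mine, not PyMC's): extract each variable as a plain numpy array once up front, then use integer indexing in the inner loops instead of a per-draw `.sel()` call.

```python
import numpy as np
import xarray as xr


def dataset_to_point_list_fast(ds: xr.Dataset):
    # Pull each variable out as a numpy array once, with (chain, draw) as the
    # leading dimensions, so the inner loops do cheap integer indexing instead
    # of label-based .sel() lookups.
    arrays = {vn: da.transpose("chain", "draw", ...).values for vn, da in ds.items()}
    points = []
    for c in range(ds.sizes["chain"]):
        for d in range(ds.sizes["draw"]):
            points.append({vn: arr[c, d] for vn, arr in arrays.items()})
    return points
```

This keeps the same output shape (a list of dicts, one per draw) while avoiding the repeated label lookups.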
My trace data is rather big:

- 7000 draws per model parameter
- 13 model parameters, each 4-dimensional, so 52 in total
- I am saving my model trace to an .h5 file via `trace.posterior.to_netcdf()`
- version: PyMC 4.0.1
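For reference, this is roughly my save/load round trip, with a small dummy dataset standing in for `trace.posterior` (the file name and variable are made up):

```python
import numpy as np
import xarray as xr

# Dummy stand-in for trace.posterior: 2 chains x 3 draws of one variable.
posterior = xr.Dataset({"x": (("chain", "draw"), np.arange(6.0).reshape(2, 3))})

posterior.to_netcdf("posterior.h5")       # how I save the trace today
loaded = xr.open_dataset("posterior.h5")  # how I load it back later
# `loaded` (an xarray.Dataset) is then passed as `trace=` to
# sample_posterior_predictive, which routes it through the slow
# dataset_to_point_list path shown above.
```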
My questions:

- What is the recommended way to save & load a model trace?
  - It seems that `dataset_to_point_list` will not be called if
    - `isinstance(trace, MultiTrace)` (how do I save & load a `MultiTrace` object?), or
    - `isinstance(trace, list) and all(isinstance(x, dict) for x in trace)` (this seems to be the output of `dataset_to_point_list`).
  - I suppose in my case I could save & load `dataset_to_point_list(trace)` instead, but that loses the chain information. Is there a better way?
- Is there any particular reason why `dataset_to_point_list` is implemented with nested for-loops? Will there be a more efficient version in the future?
Thanks!