sample_posterior_predictive is slow because of dataset_to_point_list


I am trying to sample from my posterior predictive distribution, but sample_posterior_predictive takes a very long time (a few minutes) before the progress bar even appears, and it is slow even when predicting a single sample. After reading through the code, I think the dataset_to_point_list function is the likely cause of the performance issue.

In sample_posterior_predictive():

elif isinstance(trace, xarray.Dataset):
    idata_kwargs["coords"].setdefault("draw", trace["draw"])
    idata_kwargs["coords"].setdefault("chain", trace["chain"])
    _trace = dataset_to_point_list(trace)
    nchain, len_trace = chains_and_samples(trace)

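dataset_to_point_list builds its output with what amounts to a triple loop (over chains, over draws, and over variables in the dict comprehension), calling .sel once per variable per draw, which can be quite slow.
In dataset_to_point_list():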

for c in ds.chain:
    for d in ds.draw:
        points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
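For comparison, here is a minimal sketch of building the same chain-major list of point dicts with plain NumPy indexing instead of one .sel call per variable per draw. It is not PyMC's implementation: arrays_to_point_list is a hypothetical helper, and it assumes every variable's array has (chain, draw) as its two leading axes, as you would get from {vn: da.values for vn, da in ds.transpose("chain", "draw", ...).items()}.

```python
import numpy as np


def arrays_to_point_list(arrays):
    """Hypothetical helper: turn {name: array of shape (chain, draw, *shape)}
    into a chain-major list of {name: array of shape (*shape)} dicts."""
    flat = {}
    n_samples = None
    for name, arr in arrays.items():
        arr = np.asarray(arr)
        n = arr.shape[0] * arr.shape[1]
        if n_samples is not None and n != n_samples:
            raise ValueError(f"{name!r} has {n} samples, expected {n_samples}")
        n_samples = n
        # Merge chain and draw into one sample axis; C order keeps the same
        # chain-major ordering as the nested chain/draw loops.
        flat[name] = arr.reshape((n,) + arr.shape[2:])
    if n_samples is None:
        return []
    # Integer indexing returns a view per sample (or a NumPy scalar for
    # scalar variables), avoiding xarray's label-based lookup each time.
    return [{name: a[i] for name, a in flat.items()} for i in range(n_samples)]


# Usage: two chains, three draws, one vector variable and one scalar variable.
rng = np.random.default_rng(0)
arrays = {"mu": rng.normal(size=(2, 3, 4)), "sigma": rng.normal(size=(2, 3))}
points = arrays_to_point_list(arrays)
print(len(points))  # 6
```

This still loops over samples in Python, but each step is a cheap NumPy index rather than an xarray selection, which is where the original loop plausibly spends most of its time.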

My trace is fairly large:

  • 7000 draws per model parameter
  • 13 model parameters, each 4-dimensional, so 52 in total
  • I save the trace to an .h5 file via trace.posterior.to_netcdf()
  • PyMC version: 4.0.1
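A rough back-of-the-envelope count suggests the raw data is modest and that the cost comes from the number of .sel calls the loop makes. This assumes each parameter holds 4 float64 values per draw (52 values per draw) and counts per chain, since the number of chains is not stated:

```python
# Figures from the bullet list above; the float64 dtype is an assumption.
draws_per_chain = 7000
n_variables = 13       # one .sel call per variable per (chain, draw)
values_per_draw = 52   # 13 parameters x 4 elements each
bytes_per_value = 8    # float64


def per_chain_cost(draws, n_vars, n_values, itemsize):
    """Return (bytes of posterior data, number of .sel calls) for one chain."""
    data_bytes = draws * n_values * itemsize
    sel_calls = draws * n_vars
    return data_bytes, sel_calls


data_bytes, sel_calls = per_chain_cost(
    draws_per_chain, n_variables, values_per_draw, bytes_per_value
)
print(f"{data_bytes / 1e6:.1f} MB, {sel_calls} .sel calls per chain")
# 2.9 MB, 91000 .sel calls per chain
```

Under these assumptions a few megabytes per chain turn into tens of thousands of separate xarray indexing operations, so per-call overhead rather than data volume is the more likely bottleneck.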

My questions:

  • What is the recommended way to save & load a model trace?

    • It seems that dataset_to_point_list will not be called if
      • isinstance(trace, MultiTrace) - How do I save & load a MultiTrace object?
      • isinstance(trace, list) and all(isinstance(x, dict) for x in trace) - This seems to be the output of dataset_to_point_list.
    • I suppose in my case I could save & load dataset_to_point_list(trace) instead, but that loses the number of chains. Is there a better way?
  • Is there a particular reason why dataset_to_point_list is implemented with three nested loops?

    • Will there be a more efficient version in the future?
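On keeping the chain information: if you bypass the Dataset path, one option is to store the raw arrays with their (chain, draw) axes intact and only flatten them after loading, so the chain count can be read back from the array shapes. The sketch below uses NumPy's np.savez and np.load purely for illustration; it is not an official PyMC or ArviZ storage format, and save_posterior_arrays and load_posterior_arrays are hypothetical helpers.

```python
import os
import tempfile

import numpy as np


def save_posterior_arrays(path, arrays):
    """Hypothetical helper: save {name: (chain, draw, *shape) array} to .npz."""
    np.savez(path, **arrays)


def load_posterior_arrays(path):
    """Load the arrays back and recover (nchain, ndraw) from the leading axes."""
    with np.load(path) as data:
        arrays = {name: data[name] for name in data.files}
    shapes = {arr.shape[:2] for arr in arrays.values()}
    if len(shapes) != 1:
        raise ValueError(f"inconsistent (chain, draw) shapes: {shapes}")
    nchain, ndraw = shapes.pop()
    return arrays, nchain, ndraw


rng = np.random.default_rng(1)
posterior = {"mu": rng.normal(size=(4, 100, 4)), "sigma": rng.normal(size=(4, 100))}
path = os.path.join(tempfile.mkdtemp(), "posterior.npz")
save_posterior_arrays(path, posterior)
loaded, nchain, ndraw = load_posterior_arrays(path)
print(nchain, ndraw)  # 4 100
```

Because the (chain, draw) axes survive the round trip, the arrays can be flattened into point dicts only at prediction time without losing the number of chains.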



I think this will be improved in the future. Ideally, posterior predictive sampling would broadcast along the chain and draw dimensions instead of looping over them. I think @ricardoV94 and @lucianopaz have some ideas on this.
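To illustrate what broadcasting along the chain and draw dimensions means, here is a NumPy sketch for a hypothetical linear-regression model y ~ Normal(alpha + beta * x, sigma). The posterior arrays keep their (chain, draw) axes, and a single call generates every posterior predictive sample at once with no Python loop over samples. The model, names, and shapes are assumptions for illustration; this is not how pm.sample_posterior_predictive works internally.

```python
import numpy as np


def posterior_predictive_linear(alpha, beta, sigma, x, rng):
    """Hypothetical model y ~ Normal(alpha + beta * x, sigma).

    alpha, beta, sigma: arrays of shape (chain, draw)
    x: array of shape (n_obs,)
    returns: array of shape (chain, draw, n_obs)
    """
    # Add a trailing axis so (chain, draw, 1) broadcasts against (n_obs,).
    mu = alpha[..., None] + beta[..., None] * x
    # Generator.normal broadcasts loc and scale, drawing all samples in one call.
    return rng.normal(loc=mu, scale=sigma[..., None])


rng = np.random.default_rng(42)
n_chain, n_draw = 4, 7000
alpha = rng.normal(size=(n_chain, n_draw))
beta = rng.normal(size=(n_chain, n_draw))
sigma = np.abs(rng.normal(size=(n_chain, n_draw))) + 0.1
x = np.linspace(0.0, 1.0, 5)
y_pred = posterior_predictive_linear(alpha, beta, sigma, x, rng)
print(y_pred.shape)  # (4, 7000, 5)
```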

As a temporary workaround, and depending on your model, you can try using xarray-einstats to generate posterior predictive samples from the posterior ones. It may be faster, and it already works with Dask, so it can handle arrays of posterior predictive samples that don't fit in memory. It is, however, much less convenient than pm.sample_posterior_predictive. I have some examples of generating posterior predictive samples in CmdStanPy and ArviZ integration | Oriol unraveled. That post uses cmdstanpy for posterior sampling, but once you have an az.InferenceData object the computations no longer depend on the PPL used, so the posterior predictive sampling will look similar in your case.

Thank you!