Plotting Sample_Posterior_Predictive output with Arviz

Before PyMC v4 was released, I used ArviZ to plot the output of sample_posterior_predictive. Specifically, the documentation said to use arviz.from_pymc3(). Unfortunately, this method no longer exists, and I have not been able to find the alternative in the documentation.

The posterior predictive sample is now generated as an InferenceData object. Since an InferenceData object should be compatible with the ArviZ plotting methods, I assumed the following would work:

posterior = pm.sample_posterior_predictive(trace, model=model, return_inferencedata=True)

but I receive the error KeyError: 0.

Could someone please help me find the correct way to plot the sample posterior predictive output? Really appreciate the help!


In general, I think you call pm.sample() and get an InferenceData object back:

with model:
    idata = pm.sample()

And then you call pm.sample_posterior_predictive() and ask for the posterior predictive samples to be added to the existing InferenceData object:

with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

I think az.plot_ppc() plots both the posterior predictive and the observed data. The observed data is saved into the InferenceData object when you call pm.sample(), but if you just do this

posterior = pm.sample_posterior_predictive(trace, model=model, return_inferencedata=True)

then I think you will only have the posterior predictive samples in posterior and az.plot_ppc() will choke.
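One way to check what an InferenceData object actually contains before plotting is to look at its groups. The sketch below builds two small objects by hand with az.from_dict, using random numbers in place of real model output. The variable name y, the array shapes, and the helper has_ppc_groups are hypothetical, and it assumes an ArviZ 0.x release where az.from_dict accepts the groups as keyword arguments.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(42)
n_chains, n_draws, n_obs = 2, 100, 30

# Random stand-ins for real output (assumption: one observed variable "y").
y_pp = rng.normal(size=(n_chains, n_draws, n_obs))
y_obs = rng.normal(size=n_obs)

# One object with only posterior predictive samples, one with both groups.
only_pp = az.from_dict(posterior_predictive={"y": y_pp})
both = az.from_dict(
    posterior_predictive={"y": y_pp},
    observed_data={"y": y_obs},
)


def has_ppc_groups(idata):
    """Hypothetical helper: True if both groups used for a PPC plot exist."""
    return {"posterior_predictive", "observed_data"} <= set(idata.groups())


print(only_pp.groups(), has_ppc_groups(only_pp))
print(both.groups(), has_ppc_groups(both))
```

Running the same check on the object returned by pm.sample_posterior_predictive() should show whether the observed data made it in.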


You might also want to take a look at Prior and Posterior Predictive Checks — PyMC 5.0.2 documentation

Hi, thanks for your response, I have tried the following as you suggested:

with model:
    idata = pm.sample(num_samples, cores=1)
    posterior = pm.sample_posterior_predictive(idata, extend_inferencedata=True)

Unfortunately I still get the error KeyError: 0 thrown from the az.plot_ppc method.

What arviz and pymc versions are you using? Also, could you copy the output of posterior.posterior_predictive and posterior.observed_data here in a code block? (Use print() if you get an HTML repr.)

Can you share the complete error traceback?

Hi @OriolAbril, I made a mistake: the KeyError: 0 I reported was not being thrown from the az.plot_ppc method but from further down in my code. What confused me was that the call to az.plot_ppc(posterior) did not show the plot automatically. It was necessary to call it like this:

az.plot_ppc(posterior, show=True)

After adding the show argument I am able to see the plot, and can confirm that the code snippet I posted above does work! Thanks for the help.
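For anyone running this from a script rather than a notebook, an alternative to show=True is to pass show=False and save the figure to a file. This is only a sketch: the InferenceData object is built from random numbers with az.from_dict (assuming an ArviZ 0.x release), the variable name y and the output file name are made up, and the Agg backend is selected only so the example runs without a display.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line to get a window with show=True
import matplotlib.pyplot as plt
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Random stand-in data (assumption: one observed variable "y").
idata = az.from_dict(
    posterior_predictive={"y": rng.normal(size=(2, 100, 30))},
    observed_data={"y": rng.normal(size=30)},
)

# Draw the plot without calling the backend's show function.
az.plot_ppc(idata, num_pp_samples=50, show=False)

# Save the figure that plot_ppc just drew on (hypothetical output path).
out_path = os.path.join(tempfile.gettempdir(), "ppc_example.png")
plt.gcf().savefig(out_path)
print(out_path)
```

With an interactive backend, calling plt.show() after az.plot_ppc() should have the same effect as passing show=True.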