How do I set the prior in an ArviZ InferenceData object from pm.sample_prior_predictive()?

I would like to use az.plot_dist_comparison() to visually compare the prior and the posterior of a variable. az.plot_dist_comparison() takes an az.InferenceData object as its argument and reads both the prior samples and the posterior samples from it. I have an az.InferenceData object with posterior samples, created via a call to pm.sample(). I also have prior samples, created via a call to pm.sample_prior_predictive().

But the prior samples are not within the az.InferenceData object. Instead pm.sample_prior_predictive() returns the samples as a dictionary with variable names as keys and numpy arrays as values. How do I take that dict and add it as the prior on the az.InferenceData object? Or maybe there is a simpler path, some way to tell pm.sample_prior_predictive() to add the prior samples to an existing az.InferenceData object?

I think it's not yet possible to have pm.sample_prior_predictive add the samples to an InferenceData object directly, but it is on the roadmap (there may even be a PR open for that already). For now you should use az.from_pymc3 together with InferenceData.extend, something like:

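```python
import arviz as az
import pymc3 as pm

# `model` is your pm.Model; `idata` is the InferenceData holding the posterior
with model:
    prior = pm.sample_prior_predictive()  # dict: variable name -> numpy array
    idata.extend(az.from_pymc3(prior=prior))  # adds the prior group(s) to idata

az.plot_dist_comparison(idata)
```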
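If you would rather not go through az.from_pymc3, you can also wrap the dict yourself with az.from_dict. This is a minimal sketch of that alternative, assuming an ArviZ 0.x release that provides az.from_dict and InferenceData.extend; the variable name `mu` and the random arrays are hypothetical stand-ins for your pm.sample() result and your pm.sample_prior_predictive() dict. The key step is adding a chain dimension: ArviZ expects arrays shaped (chain, draw, ...), while the prior dict holds arrays shaped (draw, ...).

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Hypothetical stand-in for an existing InferenceData with a posterior group,
# shaped (chain, draw). In practice this would come from pm.sample().
idata = az.from_dict(posterior={"mu": rng.normal(1.0, 0.2, size=(2, 500))})

# Hypothetical stand-in for the dict returned by pm.sample_prior_predictive():
# variable name -> numpy array of shape (draw, ...), with no chain dimension.
prior_samples = {"mu": rng.normal(0.0, 1.0, size=500)}

# Add a leading chain axis of length 1 so ArviZ reads each array as (chain, draw, ...).
prior_group = {name: values[np.newaxis, ...] for name, values in prior_samples.items()}

# Build an InferenceData that holds only a `prior` group and merge it into idata.
idata.extend(az.from_dict(prior=prior_group))

print(idata.groups())
```

After this, az.plot_dist_comparison(idata, var_names=["mu"]) should find both groups. One difference from az.from_pymc3: this puts every key of the dict into the `prior` group, so draws of observed variables are not split out into a separate `prior_predictive` group.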

That explains why I found no examples!