How to make InferenceData returned by sample() aware of the prior and posterior_predictive

In the arviz docs, there is this example using from_pymc3:

trace = pm.sample(draws, chains=chains)
prior = pm.sample_prior_predictive()
posterior_predictive = pm.sample_posterior_predictive(trace)

pm_data = az.from_pymc3(
    trace=trace,
    prior=prior,
    posterior_predictive=posterior_predictive,
    coords={"school": np.arange(eight_school_data["J"])},
    dims={"theta": ["school"], "theta_tilde": ["school"]},
)

So pm_data is an InferenceData object that contains the trace, the prior, and the posterior predictive samples.

But soon pm.sample will return an InferenceData by default. How do I make that InferenceData aware of the prior and posterior_predictive?

There is this example in the “Example of InferenceData schema in PyMC3” guide from ArviZ:

dims_pred={
    "slack_comments": ["candidate developer"],
    "github_commits": ["candidate developer"],
    "time_since_joined": ["candidate developer"],
}
with model:
    pm.set_data({"time_since_joined": candidate_devs_time})
    predictions = pm.sample_posterior_predictive(trace)
    az.from_pymc3_predictions(
        predictions, 
        idata_orig=idata_pymc3, 
        inplace=True,
        coords={"candidate developer": candidate_devs},
        dims=dims_pred,
    )

However, this function adds new groups to the InferenceData object, such as predictions_constant_data, and leaves the posterior_predictive group unchanged.
Therefore, it doesn’t solve this problem.

Another option is to use the InferenceData.add_groups() method, but this feels hacky. My main concern is that I would not be following standard ArviZ practice, and that this could produce annoying or misleading results later. Hopefully an ArviZ dev will chime in.

nb_trace.add_groups({"posterior_predictive": daysabs_post_pred})
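
For what it’s worth, here is a minimal sketch of the add_groups route on synthetic data. It assumes an ArviZ 0.x release, where az.from_dict takes groups as keyword arguments and InferenceData.add_groups accepts a dict of arrays. The names idata, pp_draws, and y are made up for illustration and are not from the original model. As far as I can tell, ArviZ treats the first two array dimensions as (chain, draw), so the raw output of pm.sample_posterior_predictive may need reshaping before it is added this way.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(1)
n_chains, n_draws, n_obs = 2, 50, 8

# Stand-in for the InferenceData returned by pm.sample (hypothetical data).
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chains, n_draws))})

# Stand-in for posterior predictive draws, already shaped (chain, draw, obs).
pp_draws = {"y": rng.normal(size=(n_chains, n_draws, n_obs))}

# Attach the draws as a new posterior_predictive group.
idata.add_groups({"posterior_predictive": pp_draws})

has_group = "posterior_predictive" in idata.groups()
pp_shape = idata.posterior_predictive["y"].shape
```

Note that add_groups raises an error if the group already exists, so this only works for a group that is not yet in the object.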

Edits

  1. Edited to say that the function does not solve this specific problem.
  2. Add .add_groups() method.

I’d recommend taking a look at the “A Hierarchical model for Rugby prediction” and “A Primer on Bayesian Methods for Multilevel Modeling” notebooks in the PyMC3 documentation.

Thanks! It looks like InferenceData.extend and the idata_orig argument to arviz.from_pymc3_predictions will do what I’m looking for.
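
In case it helps anyone else, here is a rough sketch of the extend approach on synthetic data, assuming an ArviZ 0.x release where az.from_dict takes groups as keyword arguments. In a real PyMC3 workflow I would expect the second object to come from az.from_pymc3(prior=..., posterior_predictive=...) inside the model context rather than from az.from_dict. The variable names and array shapes below are invented for illustration.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 2, 50, 8

# Stand-in for the InferenceData returned by pm.sample (hypothetical data).
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chains, n_draws))})

# Stand-in for the converted prior and posterior predictive samples.
extra = az.from_dict(
    prior={"mu": rng.normal(size=(1, n_draws))},
    posterior_predictive={"y": rng.normal(size=(n_chains, n_draws, n_obs))},
)

# extend modifies idata in place, copying over the groups it does not have yet.
idata.extend(extra)

groups = idata.groups()
```

After this, idata holds the posterior, prior, and posterior_predictive groups in one object, which is what the original from_pymc3 call produced.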