It goes through each chain one at a time. However, you are not forced to supply a MultiTrace
instance; you can also supply a list of point dictionaries. You can easily get those from the trace by doing the following:
```python
import pymc3 as pm

with pm.Model():
    ...  # define your model here
    trace = pm.sample()
    df = pm.trace_to_dataframe(trace,
                               varnames=['x', 'y'],  # hypothetical: the variables you want
                               include_transformed=True)
    # We have to supply the samples kwarg because it cannot be inferred if the
    # input trace is not a MultiTrace instance
    ppc = pm.sample_posterior_predictive(trace=df.to_dict('records'),
                                         samples=len(df))
```
As you can see, we first get a dataframe and convert it to a list of records (one dict per draw). At the dataframe level, you can do any indexing or chain manipulation you want or need before passing it to sample_posterior_predictive.
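As an illustration of that dataframe-level step, here is a minimal sketch that only needs pandas and NumPy. The DataFrame is a hand-built stand-in for the output of pm.trace_to_dataframe; the column names (mu, sigma_log__, sigma) are hypothetical, and it assumes draws are stacked one chain after another (2 chains of 1000 draws). It keeps only the first chain, thins it to every 10th draw, and converts the result into the list of point dictionaries that sample_posterior_predictive accepts.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for pm.trace_to_dataframe(trace, ..., include_transformed=True):
# 2 chains x 1000 draws, assumed stacked chain after chain, one column per variable.
rng = np.random.default_rng(0)
n_chains, n_draws = 2, 1000
df = pd.DataFrame({
    "mu": rng.normal(size=n_chains * n_draws),
    "sigma_log__": rng.normal(size=n_chains * n_draws),
})
df["sigma"] = np.exp(df["sigma_log__"])

# Chain manipulation: keep only the first chain (rows 0..n_draws-1 under the stacking assumption)
first_chain = df.iloc[:n_draws]
# Thinning: keep every 10th draw of that chain
thinned = first_chain.iloc[::10]

# Convert to a list of point dictionaries, one dict per draw
points = thinned.to_dict("records")
print(len(points))        # 100
print(sorted(points[0]))  # ['mu', 'sigma', 'sigma_log__']

# With PyMC3, inside the model context, this list would then be passed as:
# ppc = pm.sample_posterior_predictive(trace=points, samples=len(points))
```

Because the list no longer carries chain information, samples has to be set explicitly, here to len(points).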
A related question was asked here.