I’m working on the “streaming timeseries data” use case for SMC: first run `sample_smc()` on the first data point, pass the resulting posterior as the starting points for another `sample_smc()` run on the first two data points, and so on until all the data has been processed.

I’m getting stuck on the `start` argument to `sample_smc()`. I’ve worked out how to get output from `sample_smc()` that contains the transformed (aka “unconstrained”) variables, but I’ve tried:
```python
with pm.Model() as model:
    …
    trace = pm.sample_smc(
        …,
        start=trace,
        # start=trace["posterior"],
        # start=model.rvs_to_initial_values(trace),
        # start=model.rvs_to_initial_values(trace["posterior"]),
    )
```
including each of the commented-out versions, and nothing seems to work. Help?
Solved it.
The key helper function is:
```python
def extract_starting_points(trace):
    """
    Extract posterior samples from a PyMC SMC trace to use as starting points for the next SMC run.

    Parameters
    ----------
    trace : arviz.InferenceData or None
        The trace object returned by `pm.sample_smc()`.

    Returns
    -------
    start_dict : dict or None
        A dictionary mapping each variable name to an array of starting points,
        with the chain and draw dimensions merged into a single leading axis.
    """
    # No previous run (e.g. the first data point): return None so SMC starts from the prior
    if trace is None:
        return None
    posterior = trace.posterior
    # Dictionary to store the starting points for each variable
    start_dict = {}
    # Loop through each variable in the posterior samples
    for var in posterior.data_vars:
        # Values have shape (chain, draw, *var_shape); merge chain and draw
        # into one axis so that each row is one particle
        values = posterior[var].values
        start_dict[var] = values.reshape(-1, *values.shape[2:])
    return start_dict
```
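A minimal numpy-only sketch of what the reshape in the helper does, assuming the usual ArviZ layout where each posterior variable has dims `(chain, draw, *var_shape)`. The plain dictionary stands in for `trace.posterior`; the helper name `flatten_chains` and the variable names `mu`, `sigma_log__`, and `beta` are hypothetical.

```python
import numpy as np


def flatten_chains(posterior):
    """Merge the leading (chain, draw) axes of each array into one particle axis."""
    return {name: arr.reshape(-1, *arr.shape[2:]) for name, arr in posterior.items()}


rng = np.random.default_rng(0)
# Hypothetical stand-in for trace.posterior: 4 chains x 500 draws of a scalar `mu`,
# a log-transformed scalar `sigma_log__`, and a length-3 vector `beta`
fake_posterior = {
    "mu": rng.normal(size=(4, 500)),
    "sigma_log__": rng.normal(size=(4, 500)),
    "beta": rng.normal(size=(4, 500, 3)),
}

start = flatten_chains(fake_posterior)
for name, arr in start.items():
    print(name, arr.shape)
# mu (2000,)
# sigma_log__ (2000,)
# beta (2000, 3)

# Rows are laid out chain by chain: particle 500 is chain 1, draw 0
assert np.array_equal(start["beta"][500], fake_posterior["beta"][1, 0])
```

Variables with extra dimensions keep them after the particle axis, which is why the helper reshapes with `values.shape[2:]` rather than flattening everything.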
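The other requirement is that `pm.sample_smc()` be called with `idata_kwargs={'include_transformed': True}`, so that the posterior also contains the transformed variables (the ones with names like `sigma_log__`); those are the names the SMC kernel looks up in `start`.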