Using the start argument to sample_smc()

I’m working on the “streaming timeseries data” use case for SMC: run sample_smc() on the first data point, supply the resulting posterior as start for another run of sample_smc() on the first two data points, and so on until all the data has been processed.

I’m stuck on how to use the start argument to sample_smc(). I’ve worked out how to get output from sample_smc() that contains the transformed (aka “unconstrained”) variables, and I’ve tried:


with pm.Model() as model:
    …
    trace = pm.sample_smc(
        …,
        start=trace,
        # start=trace["posterior"],
        # start=model.rvs_to_initial_values(trace),
        # start=model.rvs_to_initial_values(trace["posterior"]),
    )

including each of the commented-out versions, and nothing seems to work. Help?

Solved it.

The key helper function is:

def extract_starting_points(trace):
    """
    Extract posterior samples from a PyMC SMC trace to use as starting points for the next SMC run.

    Parameters
    ----------
    trace : arviz.InferenceData or None
        The trace object returned by `pm.sample_smc()`.

    Returns
    -------
    start_dict : dict or None
        A dictionary mapping each variable name to an array of starting points,
        with the chain and draw dimensions flattened into one leading axis.
        None if `trace` is None.
    """

    # On the first run there is no previous trace, so there is nothing to extract
    if trace is None:
        return None

    posterior_samples = trace.posterior

    # Initialize an empty dictionary to store the starting points for each variable
    start_dict = {}

    # Loop through each variable in the posterior samples
    for var in posterior_samples.data_vars:
        # Values have shape (chain, draw, *variable_shape); merge the first two axes
        values = posterior_samples[var].values
        start_dict[var] = values.reshape(-1, *values.shape[2:])

    return start_dict

And pm.sample_smc() needs the argument idata_kwargs={'include_transformed': True}, so that the returned trace also contains the transformed (unconstrained) variables.
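
For anyone checking what the reshape in the helper does, here is a small NumPy-only sketch. The array sizes (4 chains, 500 draws, a variable of shape (3,)) are made up for illustration, and it assumes the posterior arrays are ordered (chain, draw, ...) as in a standard InferenceData posterior group:

```python
import numpy as np

# Hypothetical sizes: 4 chains, 500 draws, and a variable of shape (3,)
n_chains, n_draws = 4, 500
values = np.arange(n_chains * n_draws * 3).reshape(n_chains, n_draws, 3)

# Same reshape as in extract_starting_points(): merge the chain and draw axes
flat = values.reshape(-1, *values.shape[2:])
print(flat.shape)  # (2000, 3)

# Row order is chain-major: all of chain 0's draws, then chain 1's, and so on
assert np.array_equal(flat[:n_draws], values[0])
assert np.array_equal(flat[n_draws], values[1, 0])

# A scalar variable has shape (chain, draw), so it flattens to a single axis
scalar_values = np.zeros((n_chains, n_draws))
flat_scalar = scalar_values.reshape(-1, *scalar_values.shape[2:])
print(flat_scalar.shape)  # (2000,)
```

So each entry of the returned dictionary has one row per posterior sample, with any remaining axes matching the variable's own shape.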
