How is `merge_traces` to be used?

Maybe this will help somebody: here's my approach for combining the traces of multiple trials as chains in a single ArviZ `InferenceData` object. When plotted, the chains are overlaid on the same axes.

def create_chain_dict(trace_obj):
    """Package one trace's samples as a single-chain dict for ``az.from_dict``.

    Every variable is reshaped to ``(1, -1, width)``: a leading chain axis of
    length 1, a draw axis whose length is inferred from the array size, and a
    trailing axis of fixed width (3 for the ``mu_*`` vectors and ``p``, 9 for
    the ``sigma_*`` matrices, flattened).

    Parameters
    ----------
    trace_obj : indexable by variable name
        Must provide ``mu_0..mu_2``, ``sigma_0..sigma_2`` and ``p`` as arrays
        supporting ``.reshape``; a missing name raises ``KeyError``.

    Returns
    -------
    dict
        Variable name -> reshaped array, in the order listed below.
    """
    # (variable name, trailing width); the order here fixes the dict's key order.
    layout = (
        ("mu_0", 3),
        ("mu_1", 3),
        ("mu_2", 3),
        ("sigma_0", 9),
        ("sigma_1", 9),
        ("sigma_2", 9),
        ("p", 3),
    )
    return {name: trace_obj[name].reshape(1, -1, width) for name, width in layout}

def create_inference_obj(trace_0, trace_1, *more_traces):
    """Combine several single-run traces into one InferenceData, one chain per trace.

    Parameters
    ----------
    trace_0, trace_1 : trace-like
        Traces accepted by ``create_chain_dict``.
    *more_traces : trace-like
        Optional extra traces; each becomes one additional chain. Calling
        with exactly two traces behaves as before.

    Returns
    -------
    InferenceData
        The per-trace objects concatenated along ``"chain"``, in argument
        order.
    """
    traces = (trace_0, trace_1, *more_traces)
    # Build one single-chain InferenceData per trace.
    datasets = [az.from_dict(create_chain_dict(trace)) for trace in traces]
    # az.concat takes any number of InferenceData objects, so this is no longer
    # limited to two trials.
    # NOTE(review): presumably every trace must have the same number of draws
    # for concatenation along "chain" to succeed; confirm for your data.
    return az.concat(*datasets, dim="chain")


def plot_p(inf_obj, title_str):
    """Show trace plots of ``p`` for every chain in ``inf_obj`` under one title.

    Parameters
    ----------
    inf_obj : InferenceData
        Object containing a ``p`` variable, e.g. from ``create_inference_obj``.
    title_str : str
        First line of the figure title; "Trace Plots for p" is appended.

    Returns
    -------
    The axes returned by ``az.plot_trace``. The original returned nothing, so
    existing callers are unaffected.
    """
    title = title_str + "\n Trace Plots for p"
    # compact=False plots each component of p separately; show=False defers
    # display until the title has been added.
    axes = az.plot_trace(inf_obj, compact=False, show=False, var_names=['p'])
    # Title the figure that owns these axes explicitly instead of relying on
    # pyplot's implicit "current figure".
    axes.flat[0].figure.suptitle(title, fontsize=20)
    # Bug fix: the original wrote `plt.show` without calling it, so the figure
    # was never displayed.
    plt.show()
    return axes
1 Like