Maybe this will help somebody: here's my approach for combining multiple trials as separate chains in a single ArviZ inference-data object. When plotted, the chains are overlaid on the same axes.
def create_chain_dict(trace_obj):
    """Return a dict of single-chain sample arrays keyed by variable name.

    Every variable read from ``trace_obj`` is reshaped to
    ``(1, n_draws, width)``:
    - a leading chain axis of length 1;
    - a draw axis whose length ``reshape`` infers from ``-1``;
    - a fixed trailing width per variable.

    Args:
        trace_obj: object indexed by variable name whose values support
            ``.reshape`` (assumed to be NumPy arrays -- confirm with caller).

    Returns:
        dict with keys ``mu_0``, ``mu_1``, ``mu_2``, ``sigma_0``,
        ``sigma_1``, ``sigma_2`` and ``p``, in that order.
    """
    # Trailing width of each variable. The 9 used for the sigma_* entries is
    # presumably a flattened 3x3 matrix -- NOTE(review): confirm vs. the model.
    widths = (
        ("mu_0", 3),
        ("mu_1", 3),
        ("mu_2", 3),
        ("sigma_0", 9),
        ("sigma_1", 9),
        ("sigma_2", 9),
        ("p", 3),
    )
    return {
        name: trace_obj[name].reshape(1, -1, width)
        for name, width in widths
    }
def create_inference_obj(trace_0, trace_1, *more_traces):
    """Combine several traces into one inference-data object, one chain each.

    Each trace is turned into a single-chain dict by ``create_chain_dict``.
    That dict is wrapped with ``az.from_dict``. The results are then
    concatenated along the ``chain`` dimension, so the runs appear (and plot)
    as separate chains of the same variables.

    Args:
        trace_0: first trace; becomes chain 0.
        trace_1: second trace; becomes chain 1.
        *more_traces: optional further traces, appended as additional chains
            in the order given.

    Returns:
        The combined object returned by ``az.concat``.
    """
    traces = (trace_0, trace_1, *more_traces)
    datasets = [az.from_dict(create_chain_dict(trace)) for trace in traces]
    # A single concat call handles any number of inputs along "chain".
    return az.concat(*datasets, dim="chain")
def plot_p(inf_obj, title_str):
    """Draw trace plots for ``p`` under a figure-wide title and display them.

    Args:
        inf_obj: inference data passed straight to ``az.plot_trace``.
        title_str: first line of the title; "Trace Plots for p" is appended
            on a second line.

    Returns:
        The axes array returned by ``az.plot_trace``, so callers can adjust
        the plot further.
    """
    title = title_str + "\n Trace Plots for p"
    # show=False defers display until the title has been added below.
    axes = az.plot_trace(inf_obj, compact=False, show=False, var_names=["p"])
    # Title the figure that owns these axes (not whatever pyplot treats as
    # "current"); suptitle spans the whole figure, not a single subplot.
    fig = axes.ravel()[0].get_figure()
    fig.suptitle(title, fontsize=20)
    # Must be called: a bare ``plt.show`` is only an attribute reference.
    plt.show()
    return axes