How is `merge_traces` to be used?

I have several separately sampled traces for a single model, and I was looking at merging them.
But I’m puzzled by how merge_traces is supposed to be used. Its docstring says:

mtraces : list of MultiTraces
        Each instance should have unique chain numbers.

How is this to be achieved? AFAICT if you have two MultiTraces, they are both guaranteed to (at least) have a chain zero.
Shouldn’t this function be provided with code that will “uniquify” the chain indices across the MultiTraces? This seems like a tricky thing to be left as an exercise for the reader.
Thanks

I think this is for when you have multiple single-chain trace objects and you use this function to merge them into one MultiTrace. But I have never actually used it, so my memory might be off.

To add some information here, you would have to specify a unique chain index for each chain when calling pm.sample. This is relatively easy, as you could do something like this:
chains = [pm.sample(chain_idx=i) for i in range(n_chains)]

However, even if you do this, you cannot merge them currently as you will receive an error about trying to set an attribute:

~/anaconda3/lib/python3.6/site-packages/pymc3/backends/base.py in merge_traces(mtraces)
550                 raise ValueError("Chains are not unique.")
551             base_mtrace._straces[new_chain] = strace
552     base_mtrace.report = merge_reports([trace.report for trace in mtraces])
553     return base_mtrace
554 

AttributeError: can't set attribute

I’m wondering what this is for. I don’t see any calls to it in the pymc3 codebase and it seems like it cannot merge MultiTrace objects because it doesn’t try to renumber chains.

Is there any record of it being used? Should it simply be removed? Or should it be extended to be able to renumber chains, thus making it more useful?
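For what it's worth, the renumbering step itself does not look like much code. Below is a minimal sketch in plain Python, not the pymc3 implementation: it assumes each trace can be modeled as a dict mapping a chain index to that chain's samples (a hypothetical stand-in for a MultiTrace), and `renumber_chains` is a made-up name for illustration.

```python
def renumber_chains(traces):
    """Merge dicts of {chain_index: samples}, assigning fresh unique indices.

    `traces` is a hypothetical stand-in for a list of MultiTrace objects;
    each element maps a chain index to that chain's samples.
    """
    merged = {}
    next_chain = 0
    for trace in traces:
        # Visit chains in sorted order so the result is deterministic.
        for old_chain in sorted(trace):
            merged[next_chain] = trace[old_chain]
            next_chain += 1
    return merged


# Two "traces" that both contain a chain 0, which is the collision described above.
trace_a = {0: [0.1, 0.2], 1: [0.3, 0.4]}
trace_b = {0: [0.5, 0.6]}
merged = renumber_chains([trace_a, trace_b])
print(sorted(merged))  # chain indices after merging
```

The original chain numbers are discarded here; if they matter, one could keep a mapping from (trace position, old index) to the new index alongside the merged result.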

This function would be especially important for complicated hierarchical models such that multiple chains cannot be produced without error. We need to merge multi-trace objects to perform prior/posterior checks, or even to plot overlaid trace plots.
Has anybody recently learned to use the merge_traces function?

I’d recommend using InferenceData and az.concat instead of using multitraces.

OriolAbril, thank you for the response. Overlaying the chains w/ az.concat is enabling me to graphically check convergence. It was a charlie foxtrot when I tried to compare 4+ pages of single-chain trace plots and distributions.

Maybe this will help somebody… here’s my approach for combining multiple trials as chains in an inference object. When plotted, the chains overlay on the same axis.

import arviz as az
import matplotlib.pyplot as plt

def create_chain_dict(trace_obj):
    # Reshape each variable to (chain, draw, dim) with a single chain
    chain_dict = {
        "mu_0": trace_obj['mu_0'].reshape(1, -1, 3),  # 1 chain, 1000 points, D = 3 dim
        "mu_1": trace_obj['mu_1'].reshape(1, -1, 3),  # 1 chain, 1000 points, D = 3 dim
        "mu_2": trace_obj['mu_2'].reshape(1, -1, 3),  # 1 chain, 1000 points, D = 3 dim
        "sigma_0": trace_obj['sigma_0'].reshape(1, -1, 9),  # 1 chain, 1000 points, D = 9 dim
        "sigma_1": trace_obj['sigma_1'].reshape(1, -1, 9),  # 1 chain, 1000 points, D = 9 dim
        "sigma_2": trace_obj['sigma_2'].reshape(1, -1, 9),  # 1 chain, 1000 points, D = 9 dim
        "p": trace_obj['p'].reshape(1, -1, 3),
    }
    return chain_dict

def create_inference_obj(trace_0, trace_1):
    # Put data for each chain into dictionaries
    chain_0 = create_chain_dict(trace_0)
    chain_1 = create_chain_dict(trace_1)
    # Convert the dictionaries into InferenceData objects
    data_0 = az.from_dict(chain_0)
    data_1 = az.from_dict(chain_1)

    # Combine the InferenceData objects into a single one with multiple chains
    data_combined = az.concat(data_0, data_1, dim="chain")
    return data_combined


def plot_p(inf_obj, title_str):
    title = title_str + "\n Trace Plots for p"
    az.plot_trace(inf_obj, compact=False, show=False, var_names=['p'])
    plt.suptitle(title, fontsize=20)  # figure-level title
    plt.show()