Feature Request: handle label switching in summary. Some ideas

It would be great if there were a way to handle label switching in traces. I’ve come up against this with a latent space model where the dimensions are different across traces. In this case I’d been worried about rotation and shearing, but it appears that the dimensions are consistent across chains, they’re just mixed and matched. I’ve seen other discussion on the Discourse where people have had similar issues, and I have some ideas for how to handle it.

By exploring the traces themselves you can figure out which dimensions should align, but I haven’t figured out how to change the underlying data in the trace object. For example, the following does not work. Here, I use a 3000-sample trace for a 3-dimensional latent variable (size 500x3). I try to overwrite the data in the trace object so that the sampled values for the 0th and 1st dimensions are swapped. When I then retrieve those values, they are unaltered.

import copy
# Attempt to swap dimensions 0 and 1; the retrieved values come back unaltered
trace2 = copy.copy(trace)
trace2.z[3000:5999,:,0]=trace.z[3000:5999,:,1]
trace2.z[3000:5999,:,1]=trace.z[3000:5999,:,0]

Apologies if there’s already an easy way to do this: dir(trace) lists a method to remove values, but not to change values. This would be one possible workaround, because after switching the dimensions in the trace itself, the diagnostics would all work properly.
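As a rough sketch of the manual fix, the draws can be pulled out into a NumPy array with chains kept separate and permuted there, instead of writing into the trace. The assignment in the snippet above probably fails because `trace.z` hands back a freshly built array on each access, though that is a guess rather than something confirmed here. The helper name `swap_latent_dims` and the toy array shape are assumptions for illustration; with a real trace the array might come from `np.array(trace.get_values('z', combine=False))`.

```python
import numpy as np

def swap_latent_dims(samples, chain, perm):
    """Return a copy of `samples` with the last axis of one chain reordered.

    samples: array of shape (n_chains, n_draws, ..., n_dims)
    chain:   index of the chain to fix
    perm:    new order of the latent dimensions, e.g. [1, 0, 2]
    """
    fixed = samples.copy()                    # leave the original draws untouched
    fixed[chain] = samples[chain][..., perm]  # reorder only the chosen chain
    return fixed

# Toy stand-in for real draws: 2 chains, 5 draws, 4 nodes, 3 dimensions.
rng = np.random.default_rng(0)
z = rng.normal(size=(2, 5, 4, 3))
z_fixed = swap_latent_dims(z, chain=1, perm=[1, 0, 2])
```

The corrected array can then be fed to any diagnostic that accepts plain arrays.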

Another possibility would be to allow pm.summary() and other diagnostics to take an array instead of only accepting a multitrace object.

For example, in diagnostics.py on line 172, the gelman_rubin function just converts the trace into a numpy array:
x = np.array(mtrace.get_values(var, combine=False))
The function could easily just take x directly, perhaps with a warning to make sure you know what you’re doing if you don’t use a multitrace. For my case, I just copied the functions from the pymc3 source and used numpy arrays to calculate the diagnostics, but it seems like a small change that would make pymc3 more useful for models with latent variables.
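To illustrate what an array-based diagnostic could look like, here is a minimal sketch of a plain (non-split) Gelman-Rubin statistic that works directly on an array of shape (chains, draws, ...). It is not a copy of the pymc3 source, and the name `gelman_rubin_from_array` is hypothetical. The toy data shows how swapped labels in one chain inflate the statistic.

```python
import numpy as np

def gelman_rubin_from_array(x):
    """Potential scale reduction for x of shape (n_chains, n_draws, ...)."""
    x = np.asarray(x, dtype=float)
    if x.shape[0] < 2:
        raise ValueError("Gelman-Rubin needs at least two chains")
    n = x.shape[1]
    # Between-chain variance, computed from the per-chain means.
    B = n * np.var(np.mean(x, axis=1), axis=0, ddof=1)
    # Within-chain variance, averaged over chains.
    W = np.mean(np.var(x, axis=1, ddof=1), axis=0)
    V_hat = W * (n - 1) / n + B / n
    return np.sqrt(V_hat / W)

rng = np.random.default_rng(1)
# Two chains, 1000 draws, two latent dimensions centred at -2 and +2.
chains = rng.normal(loc=[-2.0, 2.0], size=(2, 1000, 2))
switched = chains.copy()
switched[1] = chains[1][:, ::-1]   # second chain has its labels swapped

rhat_ok = gelman_rubin_from_array(chains)
rhat_switched = gelman_rubin_from_array(switched)
```

With consistent labels the values sit close to 1, while the label-switched copy is far above it even though both chains explored the same posterior.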

Using arviz is another solution here, but I had trouble figuring out the dimensions for the data structures, and how to keep chains separate. So the easiest solution might just be to add a few lines to the documentation of the arviz package (so I’ll also tag @RavinKumar )

I think a helper function to deal with label switching would indeed be helpful. Something like:

trace_new = pm.synclatent(trace_old, varnames=[...], mapping={0:1,...})

where mapping kwarg is optional.

The challenge is then to find this mapping automatically. Maybe something like k-means clustering to cluster everything together, and then use the labels from the k-means fit?

I think that might be overkill. Here was my solution, and I think it’s pretty generalizable:

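import numpy as np
import pandas as pd

# Posterior median of z (500x3) within each chain
z0 = np.percentile(trace.get_values('z', chains=[0]), 50, axis=0)
z1 = np.percentile(trace.get_values('z', chains=[1]), 50, axis=0)
# One column per (chain, dimension) pair, e.g. 'z01' is chain 0, dimension 1
z = pd.DataFrame({'z00': z0[:, 0],
                  'z01': z0[:, 1],
                  'z02': z0[:, 2],
                  'z10': z1[:, 0],
                  'z11': z1[:, 1],
                  'z12': z1[:, 2]})
# Correlate the medians across the 500 rows
z.corr()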

Here’s the output:

           z00        z01        z02        z10        z11        z12
z00   1.000000  -0.037718   0.273659  -0.047861   0.999247   0.309811
z01  -0.037718   1.000000   0.196097   0.999248  -0.019276   0.157059
z02   0.273659   0.196097   1.000000   0.227347   0.246043   0.998589
z10  -0.047861   0.999248   0.227347   1.000000  -0.030644   0.187845
z11   0.999247  -0.019276   0.246043  -0.030644   1.000000   0.281592
z12   0.309811   0.157059   0.998589   0.187845   0.281592   1.000000

Here, it was very clear that dimension 0 should be switched with dimension 1 in the second trace, because the medians of z00 and z11 (and of z01 and z10) are almost perfectly correlated. We could set a threshold for the correlation, maybe check for consistency across a handful of percentiles, and then use the matches to remap. If that fails, just say that there was no obvious remapping and ask the user to supply one. If none match, then that’s also a sign that it didn’t converge.
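The thresholding idea could be sketched roughly as follows. This is only an illustration under assumptions: `match_dimensions` and its 0.9 default threshold are hypothetical, it looks at medians only (not several percentiles), and it ignores sign flips, so a reflected dimension would be reported as unmatched.

```python
import numpy as np

def match_dimensions(ref, other, threshold=0.9):
    """Guess which column of `other` corresponds to each column of `ref`.

    ref, other: draws from two chains, each of shape (n_draws, n_nodes, n_dims).
    Returns perm such that other[..., perm] lines up with ref, or None when
    there is no clear one-to-one match.
    """
    ref_med = np.median(ref, axis=0)       # (n_nodes, n_dims)
    other_med = np.median(other, axis=0)
    k = ref_med.shape[1]
    # Cross-correlation block: rows index ref dims, columns index other dims.
    corr = np.corrcoef(ref_med, other_med, rowvar=False)[:k, k:]
    perm = np.argmax(corr, axis=1)
    best = corr[np.arange(k), perm]
    if np.any(best < threshold) or len(set(perm.tolist())) != k:
        return None                         # no obvious remapping; ask the user
    return perm

# Toy data: chain1 is chain0's latent space with dimensions 0 and 1 swapped.
rng = np.random.default_rng(2)
truth = rng.normal(size=(500, 3))
chain0 = truth + 0.05 * rng.normal(size=(200, 500, 3))
chain1 = truth[:, [1, 0, 2]] + 0.05 * rng.normal(size=(200, 500, 3))
perm = match_dimensions(chain0, chain1)
```

Applying `chain1[..., perm]` would then put the second chain in the same order as the first before running any diagnostics.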

I see, yeah that looks like a pretty good solution, but it doesn’t quite work with

  1. mode switching within chain
  2. multivariate RVs

right?

It works with 2 if you’re referring to what I think you’re referring to. The above example is multivariate. z is 500x3. If you mean adding another dimension, then I think it would still work. You would correlate within each dimension and then aggregate over the correlation coefficients in some way? Maybe they’d all have to be correlated in order to be able to swap the variables?

Didn’t think about mid-chain. That’s definitely a harder problem, because you have to decide where to cut, and I imagine there’d be a fuzzy boundary as the sampler wanders between modes.

If they’re jumping within the chain, though, then that’s a bigger problem, because you would never really be able to tell whether two dimensions were distinct, or whether they were both the average of some other two latent variables that the sampler was stuck between. Could you distinguish the following two cases: 1) two very different latent dimensions jump back and forth frequently, and 2) two latent dimensions have similar means and higher variance?

Since we’re assuming one mode per variable, this seems very HMM to me. Use a low within-trace switch probability and a higher cross-trace switch probability, and treat the emission distribution as normal with unknown mean and variance.

But (implementing in pymc3 and) using an HMM to solve an identifiability issue seems like total overkill. Is there a way to parametrize your space or loadings so that the posterior is proper?

I do not think there is a way to change the parameterization. That said, even if there were for this model, I don’t think there necessarily is for every model with latent variables that has label switching, and I found the problem pop up a handful of times on the Discourse when I was looking for a solution for my own case.

A full HMM may be overkill, but if we can think of a heuristic to sync some proportion automatically, then it might be worth it. Maybe just add a warning that the automatic function will only work if there is no within-chain switching. If there is within-chain switching, it will probably look non-converged, so you’d at least know.

But more important is some way to alter the multitrace object to correct by hand. Or to allow the diagnostics to accept arrays, but that would probably require more changes.

Another place something similar is coming up, for which it would be nice to have some way to edit the traces or feed an array into the diagnostics:

If you have to stop a trace prematurely, then each chain will be of different length, and the summary function won’t work.
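If the diagnostics accepted arrays, one possible way around unequal chain lengths would be to trim every chain to the shortest one and stack the result. A minimal sketch, where `stack_truncated` is a hypothetical helper and each chain is assumed to be available as its own array of draws:

```python
import numpy as np

def stack_truncated(chains):
    """Trim every chain to the shortest length and stack along a new chain axis."""
    n = min(len(c) for c in chains)
    return np.stack([np.asarray(c)[:n] for c in chains])

# Toy example: one chain stopped after 95 draws, the other ran for 120.
uneven = [np.zeros((120, 3)), np.ones((95, 3))]
x = stack_truncated(uneven)   # shape (n_chains, shortest_length, 3)
```

This keeps the first draws of each chain and discards the tail of the longer ones, which is a design choice; keeping the last draws instead would be just as easy.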