Definition of length of a MultiTrace

In base.py I see the following in the definition for MultiTrace:

```python
def __len__(self):
    chain = self.chains[-1]
    return len(self._straces[chain])
```

Could someone explain why this is the length of the trace? I thought that if I iterated over a MultiTrace using only an integer index, I would get back all of the points in all of the chains, but this implementation suggests that the length is defined as the length of a single, arbitrarily chosen chain. But maybe I’m missing something here?
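To make the question concrete, here is a toy stand-in that reproduces the `__len__` logic quoted above. `ToyMultiTrace` is hypothetical, not PyMC’s class: it assumes `_straces` maps chain numbers to per-chain sequences and that `chains` is the sorted list of those numbers. With chains of unequal length, `len()` reports only the highest-numbered chain.

```python
class ToyMultiTrace:
    """Hypothetical stand-in for MultiTrace: chain number -> list of points."""

    def __init__(self, straces):
        # straces: dict {chain_number: list_of_point_dicts}
        self._straces = straces

    @property
    def chains(self):
        # Chain numbers in sorted order, so chains[-1] is the highest one.
        return sorted(self._straces)

    def __len__(self):
        # Same logic as the quoted MultiTrace.__len__: last chain only.
        chain = self.chains[-1]
        return len(self._straces[chain])


trace = ToyMultiTrace({
    0: [{"mu": 0.1}, {"mu": 0.2}, {"mu": 0.3}],
    1: [{"mu": 1.1}, {"mu": 1.2}],
})
print(len(trace))  # 2: draws in chain 1 only, not the 5 points overall
```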

I guess this piece of the MultiTrace docstring explains it:

For any methods that require a single trace (e.g., taking the length
of the MultiTrace instance, which returns the number of draws), the
trace with the highest chain number is always used.

…but this seems to contradict this almost directly:

1. Indexing with a variable or variable name (str) returns all
   values for that variable, combining values for all chains.

   >>> trace[varname]

and then I see this:

2. Indexing with an integer returns a dictionary with values for
   each variable at the given index (corresponding to a single
   sampling iteration).

…which is silent on the question of whether the indexing ranges over all of the chains or only a single one (and if so, which one).

If someone could clarify this for me, I will try to submit a pull request to improve the documentation.

Also, what’s the best way to get the total number of points across all the chains? I think one could do len(list(trace.points())), but that’s a very expensive query. Is there something more efficient?

The len and indexing issue has come up before, but in the end we kept the current implementation. For example, see the discussion here:

That would be my intuition as well, but currently iterating returns samples from only one chain. This actually used to cause a bug in sample_ppc, which used the posterior samples from only one chain (that’s why the iteration now internally loops over both chains and samples).
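The fix described here amounts to drawing posterior points from every chain rather than one. Below is a minimal sketch of that idea, not PyMC’s actual sample_ppc code: the helper `sample_points_all_chains` is hypothetical, `straces` is assumed to map chain numbers to lists of points, and all chains are assumed to have the same number of draws so that a flat index can be split into (chain, draw) with `divmod`.

```python
import random


def sample_points_all_chains(straces, n, seed=None):
    """Draw n points uniformly at random from all chains (hypothetical helper).

    straces maps chain number -> list of points. Assumes every chain has the
    same number of draws, so a flat index splits into (chain, draw) via divmod.
    """
    chains = sorted(straces)
    draws = len(straces[chains[0]])
    if any(len(straces[c]) != draws for c in chains):
        raise ValueError("this sketch assumes equal-length chains")
    rng = random.Random(seed)
    total = len(chains) * draws
    samples = []
    for _ in range(n):
        flat = rng.randrange(total)            # index over all chains at once
        chain_pos, draw = divmod(flat, draws)  # which chain, which draw in it
        samples.append(straces[chains[chain_pos]][draw])
    return samples


straces = {
    0: [{"chain": 0, "mu": 0.1}, {"chain": 0, "mu": 0.2}, {"chain": 0, "mu": 0.3}],
    1: [{"chain": 1, "mu": 1.1}, {"chain": 1, "mu": 1.2}, {"chain": 1, "mu": 1.3}],
}
samples = sample_points_all_chains(straces, n=200, seed=42)
print(sorted({p["chain"] for p in samples}))
```

Sampling only from `straces[max(straces)]` instead would reproduce the one-chain behavior described above.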

1 is correct: trace[varname] does return all samples, combined across chains; see the discussion in the link above.

2 is incorrect: it only indexes into the first chain.

And yes, a PR to improve the docs would be great!

As for indexing into all points across all chains, you can use trace._straces:

Since trace._straces looks like an “internal-only” attribute, what about adding a method like all_points, a thin wrapper around points that iterates through all the chains? That would look more like a user-facing interface. If that sounds sensible, I could provide a PR for it.
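A sketch of what the proposed method might look like, on a hypothetical toy class rather than PyMC’s MultiTrace. The toy `points(chains=None)` iterates lazily over the requested chains (all of them by default), and `all_points` is the hypothetical public wrapper proposed above. If `points()` already defaults to every chain, the wrapper adds only a more explicit name.

```python
import itertools


class ToyMultiTrace:
    """Hypothetical stand-in: chain number -> list of point dicts."""

    def __init__(self, straces):
        self._straces = straces

    @property
    def chains(self):
        return sorted(self._straces)

    def points(self, chains=None):
        # Lazily concatenate the requested chains (all chains by default).
        if chains is None:
            chains = self.chains
        return itertools.chain.from_iterable(self._straces[c] for c in chains)

    def all_points(self):
        # Proposed wrapper: explicitly every point from every chain.
        return self.points(chains=self.chains)


trace = ToyMultiTrace({0: [{"mu": 0.1}, {"mu": 0.2}], 1: [{"mu": 1.1}]})
print([p["mu"] for p in trace.all_points()])        # [0.1, 0.2, 1.1]
print([p["mu"] for p in trace.points(chains=[1])])  # [1.1]
```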

Actually, a minor correction: I believe that integer indices into a trace index into the last chain, not the first. __getitem__ for integers falls through to the point method with chain=None, which means the last chain (self.chains[-1]) is used.
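The dispatch described in this correction can be mimicked with a small toy class. `ToyMultiTrace` is hypothetical and simplified (for example, it checks `isinstance(idx, int)` rather than copying PyMC’s exact type handling), but it shows the two behaviors side by side: an integer goes to `point(idx, chain=None)` and therefore to the highest-numbered chain, while a variable name collects values from all chains.

```python
class ToyMultiTrace:
    """Hypothetical stand-in: chain number -> list of point dicts."""

    def __init__(self, straces):
        self._straces = straces

    @property
    def chains(self):
        return sorted(self._straces)

    def point(self, idx, chain=None):
        # chain=None falls back to the highest-numbered (last) chain.
        if chain is None:
            chain = self.chains[-1]
        return self._straces[chain][idx]

    def __getitem__(self, idx):
        # Integers go to point(); anything else is treated as a variable name.
        if isinstance(idx, int):
            return self.point(idx)
        return [p[idx] for c in self.chains for p in self._straces[c]]


trace = ToyMultiTrace({
    0: [{"mu": 0.1}, {"mu": 0.2}, {"mu": 0.3}],
    1: [{"mu": 1.1}, {"mu": 1.2}],
})
print(trace[0])     # {'mu': 1.1}: first draw of the last chain
print(trace["mu"])  # [0.1, 0.2, 0.3, 1.1, 1.2]: all chains combined
```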

Never mind, my bad – trace.points() will give me the iterator I want.

I’d still like a way to cheaply get the number of points in the trace and the number of points in a chain. I’m too stupid to do anything but count them, which is very slow!
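One way to avoid counting point by point is to ask each per-chain trace for its length. This is a minimal sketch under stated assumptions: the helpers `chain_lengths` and `total_points` are hypothetical, and they rely on the trace exposing `chains` and a `_straces` dict whose values support `len()`, as the code quoted at the top of this thread does. `_straces` is private, so this may break between versions. For the in-memory backend this should be cheap, since no points are materialized.

```python
from types import SimpleNamespace


def chain_lengths(trace):
    """Return {chain_number: number_of_draws} without iterating over points.

    Assumes trace.chains lists the chain numbers and trace._straces maps each
    one to a per-chain trace that supports len().
    """
    return {chain: len(trace._straces[chain]) for chain in trace.chains}


def total_points(trace):
    """Total draws across all chains: the sum of the per-chain lengths."""
    return sum(chain_lengths(trace).values())


# Hypothetical stand-in exposing only the two attributes the helpers use.
trace = SimpleNamespace(
    chains=[0, 1],
    _straces={0: [{"mu": 0.0}] * 1000, 1: [{"mu": 0.0}] * 800},
)
print(chain_lengths(trace))  # {0: 1000, 1: 800}
print(total_points(trace))   # 1800
```

When all chains have the same number of draws, multiplying the number of chains by len(trace) gives the same total without touching _straces.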
