Multiple (uncertain) function observations of the same Gaussian process

Hello,

I am new to PyMC3 and it would be great if someone could provide some guidance/feedback on my implementation below! I am building a package that takes a few (~5 or so, in practice) functions that are assumed to be draws from the same Gaussian process, and attempts to learn some of the hyperparameters that are physically relevant to our model. Right now, my ObservableModel class uses a for loop (which I know is discouraged) to create multiple marginal_likelihood instances, but NUTS becomes quite slow with 2-dimensional inputs, which is where most of these functions will be fit.

Right now, it looks like marginal_likelihood is technically only built for one function y at a time. But I imagine the “best” way to do this would be something like the following:

X = ...    # A N x d array, where N is number of data points and d is dimensions
data = ... # An n x N array, where n is number of functions
noise = 1e-10
with pm.Model() as model:
    # Define RVs
    cov = ...
    gp = pm.gp.Marginal(cov_func=cov)
    gp.marginal_likelihood(
                name,
                X=X,
                y=data,  # <- Not the right shape according to documentation, but still "works"
                noise=noise,
                shape=data.shape
                )

which is how MvNormal works, rather than for func in data: .... This actually does sort of seem to work (and is much faster than the for loop), but I have to cheat a little to get other gp methods, such as conditional, to work right:

with model:
    for func in data:
        gp.conditional(new_name, X_new, given={'y': func, 'X': X, 'noise': noise})
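
For reference, here is a minimal NumPy sketch (not PyMC3 code) of why a stacked n x N array can stand in for the for loop: if every row is an independent draw from the same N(0, K), the joint log density is the sum of the per-row log densities, and the Cholesky factor of K only needs to be computed once. The kernel, the jitter value and the helper names mvn_logp and multi_logp are assumptions made for illustration.

```python
import numpy as np

def mvn_logp(y, K):
    """Log density of N(0, K) at a single vector y, via Cholesky."""
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, y)              # L^{-1} y
    return (-0.5 * alpha @ alpha
            - np.sum(np.log(np.diag(L)))       # 0.5 * logdet(K)
            - 0.5 * len(y) * np.log(2 * np.pi))

def multi_logp(Y, K):
    """Joint log density of the rows of Y (n x N), each an independent N(0, K) draw."""
    n, N = Y.shape
    L = np.linalg.cholesky(K)                  # factorised once for all rows
    alpha = np.linalg.solve(L, Y.T)            # N x n, all functions at once
    return (-0.5 * np.sum(alpha**2)
            - n * np.sum(np.log(np.diag(L)))
            - 0.5 * n * N * np.log(2 * np.pi))

rng = np.random.default_rng(0)
X = np.linspace(0, 1, 8)[:, None]
# Squared-exponential kernel with a small jitter (illustrative choice)
K = np.exp(-0.5 * (X - X.T)**2 / 0.3**2) + 1e-6 * np.eye(8)
Y = rng.multivariate_normal(np.zeros(8), K, size=5)   # 5 functions, 8 points each

loop_total = sum(mvn_logp(y, K) for y in Y)
batch_total = multi_logp(Y, K)
print(loop_total, batch_total)
```

The two totals agree up to floating-point error, which is consistent with the stacked version being faster than the loop while giving the same posterior.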

Here’s an example of the output using some toy data using conditional where the dots are observations:


The \bar c variable is just the marginal standard deviation of the gp and the \ell is the length scale.

(1) So my first question is this: is this the correct way to make multiple observations of the same GP without breaking things behind the scenes? If so, can/should the gp methods such as conditional and plot_gp_dist be updated to handle this? I would really like to keep the perks of gps rather than using MvNormal.

Using the above would all be well and good if it weren’t for another aspect of our model, which is a special type of uncertainty in the observations themselves. There is an unknown scaling factor associated with the observations that affects each function differently. This is handled by the ExpansionParameterModel class in the package. To describe it I will go into the details of the model a bit.

We are measuring what is generically called an “observable” Z (there can be many, each a function of one or two variables), whose value can be calculated to some order k in an expansion. The expansion looks like


where the Z_0 and ∆Z_i are known, fixed quantities, and the c_n are found by calculating Z up to order n and inverting the equation. In this case, the functions that are all drawn from the same Gaussian process are the c_n, but they only have such a property if Q is chosen correctly. The variable Q = p/\Lambda_b is called the expansion parameter, and is the ratio of two scales p and \Lambda_b. It is some number between 0 and 1 that is identical across observables Z. We have a pretty good (but not perfect!) idea of what it is. If instead of the “correct” Q we chose the value Q/a, then each of the c_n would be scaled by a^n, giving them different variances, so they would be drawn from different, but related, GPs.
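
To make the a^n scaling concrete, here is a small numerical sketch. It assumes (this form is an assumption, chosen to match the description above) that the order-n contribution is ∆Z_n = Z_ref * c_n * Q^n, with Z_0 treated as the n = 0 term. The function name coefficients and all numbers are made up for illustration.

```python
import numpy as np

def coefficients(delta_Z, Z_ref, Q):
    """Invert the assumed relation delta_Z_n = Z_ref * c_n * Q**n for c_n."""
    orders = np.arange(len(delta_Z))
    return delta_Z / (Z_ref * Q**orders)

# Made-up values at a single input point
Z_ref, Q_true, a = 2.0, 0.4, 1.25
c_true = np.array([1.0, -0.7, 0.3, 1.1])
delta_Z = Z_ref * c_true * Q_true**np.arange(4)        # order-by-order contributions

c_right = coefficients(delta_Z, Z_ref, Q_true)         # recovers c_true
c_wrong = coefficients(delta_Z, Z_ref, Q_true / a)     # each c_n picks up a factor a**n
print(c_wrong / c_right)
```

With the wrong expansion parameter Q/a, the extracted coefficients differ from the true ones by a^n, so the variance of the n-th curve is multiplied by a^(2n).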

Right now I do not know how to implement this scaling factor a^n that affects each c_n differently without using a for loop. Currently I loop through both creating a gp with a covariance that is properly scaled, then use marginal_likelihood, which seems to work okay but, again, is quite slow. In a perfect world, I would use the exact same implementation as described in the above “best” case, but where the observed y values are themselves uncertain: c_n * a^n for some RV a.
Below is an example of the predictions with that uncertain scaling parameter learned as described above (\Lambda_b is the uncertain scale):


Again, the dots are the “true” values, and the GP curves (which are no longer forced to go through those points because of the unknown scale) tried to guess a scale that makes all the curves look like draws from a single GP.

(2) Does anyone know a better way to implement this that would ease the sampling issues in NUTS? I am not sure if scan in theano is appropriate or some completely different parameterization of my model. Maybe building a custom likelihood would work, but I would still like to harness the power of GPs. Sorry if I did not describe this as well as I could have, I am new to PyMC3.

Thanks!

I think what you are doing is correct. Not sure if it is easy to add it to GPs tho, @bwengals?

The best approach for this problem would be to find a reparameterization of the Cov function to achieve the scaling effect you want.

Thank you for your response, but can you elaborate on what you mean by the reparameterization of the cov function? A simplistic version of my implementation looks something like

with pm.Model() as model:
    # Define RVs
    cov = ...
    for n, c_n in enumerate(data):
        scaled_cov = scale**(2*n) * cov
        gp = pm.gp.Marginal(cov_func=scaled_cov)
        gp.marginal_likelihood(
                    name,
                    X=X,
                    y=c_n,
                    noise=noise,
                    )

which is the only way I can think of to force the GP to treat the c_n functions differently, but involves looping and many gp.Marginal instances. What sort of reparameterization would handle this?

One possible fix, which I don’t know how to implement (but maybe you can help :grinning:), would be something like the following. As I discussed in the original post, the actual observations are the “observables” Z, calculated to some order k in an expansion, which I denote here as Z_k. Now Z_k can be broken up into the contribution from each order, which are the Z_0 and ∆Z_i. Everything discussed so far has a known, observed value. Assuming the overall factor Z_ref is known, we can solve for the products of random variables in terms of stuff that is known:

The point is, the stuff on the right-hand side of the equations is observed, so ideally these are the quantities that would be passed into the observed kwarg of some custom distribution for both the c_n and Q. The c_ns and Q are all RVs (or GPs), neither of which are observed individually but only in combination. Maybe this is where some custom likelihood would come in:

The likelihood is really just delta functions or something that are implementing the above equalities and the priors on the c_n could be a single latent gp (maybe with some sort of multiple function hack I described originally) and the prior on Q could be anything, a combination of standard RVs or maybe even its own latent process.

Is this something I could define a custom likelihood or distribution to handle? It would have to observe a product of random variables (or gaussian processes), rather than any individual RV. And would this be a case where Latent gps should be used instead of Marginals? I’ve been reading about defining custom likelihoods and whatnot, but I do not know how to go about doing this.

What I meant by reparameterization is that you unroll the covariance matrix into a big one:

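import numpy as np
import pymc3 as pm
import theano.tensor as tt

# Stack the curves end to end: [c_0(X), c_1(X), ...], total length n*N
data_flat = np.concatenate([np.asarray(c_n) for c_n in data])  # 1d
with pm.Model() as model:
    # Define RVs (including scale, so the scaling matrix is built inside the model)
    cov = ...
    # scaling mat: diagonal entries scale**(2*n), one per curve
    scaling = tt.nlinalg.diag(tt.stack([scale**(2*n) for n in range(len(data))]))
    cov2 = tt.slinalg.kron(scaling, cov(X))
    # likelihood
    obs = pm.MvNormal('obs', mu=np.zeros(data_flat.shape), cov=cov2, observed=data_flat)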

You can not use the GP API but in principle this should work.
[Edit] You can still use the GP API by wrapping it in a custom Covariance, see answer below.
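
As a sanity check on the unrolling idea, the following NumPy sketch (outside PyMC3, with a made-up kernel and a fixed scale rather than an RV) compares the log density of the flattened data under the big Kronecker covariance with the sum of per-curve log densities under individually scaled covariances. The helper mvn_logp is an illustrative name.

```python
import numpy as np

def mvn_logp(y, K):
    """Log density of N(0, K) at y, via Cholesky."""
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, y)
    return (-0.5 * alpha @ alpha
            - np.sum(np.log(np.diag(L)))
            - 0.5 * len(y) * np.log(2 * np.pi))

rng = np.random.default_rng(1)
N, n = 6, 3                                   # points per curve, number of curves
X = np.linspace(0, 1, N)[:, None]
K = np.exp(-0.5 * (X - X.T)**2 / 0.3**2) + 1e-6 * np.eye(N)

scale = 1.3                                   # fixed here; an RV in the real model
factors = scale ** (2 * np.arange(n))         # variance factor of curve n
data = np.stack([np.sqrt(f) * rng.multivariate_normal(np.zeros(N), K)
                 for f in factors])           # n x N

# Row-major flattening matches the block order of kron(diag(factors), K)
big_cov = np.kron(np.diag(factors), K)
joint = mvn_logp(data.ravel(), big_cov)
looped = sum(mvn_logp(c_n, f * K) for c_n, f in zip(data, factors))
print(joint, looped)
```

Both numbers agree, so the single MvNormal with the Kronecker covariance describes the same model as the loop over scaled GPs.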

RE your first question, yes, I am nearly certain that using a multicolumn y will correctly handle GPs with multiple observations. Your result looks correct to me. The reason I say nearly is that I can’t think of anything in the code that would not facilitate this, but multi-observed GPs have not been tested. This is the only reason why the spec says the shape of y has to be (n, ). So I think it’s really fantastic that it works for you (it looks great to me)! I’m really excited to see PyMC3 GPs being used in your project. You are more than welcome to submit a PR to pymc3 to add this or verify that the implementation for multi-output GPs is correct, and I’d be more than happy to either collaborate or assist you with it.

plot_gp_dist would certainly benefit from some improvements. I wrote it as a quick convenience function for plots of 1D GPs, mostly for use making the documentation. If folks find it useful it should be fixed up.

Maybe setting the shape arg will help get conditional to work?

In short yes they definitely can and should :+1:

As far as this covariance structure, you can do what @junpenglao suggested, and it’s not too difficult to wrap it in a custom Covariance, like is shown here. Then you can continue to use the GP API.

Can’t be sure what might be causing NUTS to slow down when you move to 2D inputs. Evaluating logp for GPs is quite slow due to the $O(n^3)$ scaling, and NUTS does rely on many logp evaluations. So some slowness here might just be the cost of doing business. For quick-and-dirty, find_MAP will find the mode pretty reliably if it is given reasonable starting points.

@bwengals It is good to know that this almost certainly works. I would happily submit a PR and/or check that all is well with the implementation, but as I have not done these things before, help would be appreciated! :slight_smile: Personally, I like plot_gp_dist and have found it quite useful for quickly checking results. As far as I can tell, the shape arg in conditional doesn’t actually get used when passed through the given dictionary (_get_given_vals seems to ignore it). The only way I’ve gotten conditional to work curve by curve in the “multi-observation” case is by passing y, X, and noise so that the internals are overridden for each y curve that I want to predict. Is this what you meant by shape?

Alright I have made my first pass at implementing both of your suggestions, which I’ll put here just in case it helps someone. Right now my ObservableModel class is set up to use the multiple observed y if the functions need not be scaled separately, and use the Kronecker product idea suggested by @junpenglao if they are scaled in the way described above.

if self.expansion_parameter is None:
    # Create a multi-observed GP
    self.gp = pm.gp.Marginal(cov_func=cov)
    obs = self.gp.marginal_likelihood(
            'obs',
            X=self.X,
            y=self.data,
            noise=self.noise,
            shape=self.data.shape  # <- Might be unnecessary
            )
else:
    # Expand cov as kronecker product and learn all data at once
    scale = self.expansion_parameter.scale
    scaled_cov = ScaledCov(cov, scale, self.index_list)
    # Repeat self.X so it is same shape as data_flat
    X_concat = np.concatenate(tuple(self.X for n in self.index_list))
    data_flat = self.data.flatten()

    self.gp = pm.gp.Marginal(cov_func=scaled_cov)
    obs = self.gp.marginal_likelihood(
            'obs',
            X=X_concat,
            y=data_flat,
            noise=self.noise,
            )

where I have wrapped the Kronecker implementation in a custom Covariance:

class ScaledCov(pm.gp.cov.Covariance):
    """Create a big kronecker product of scale and cov.

    Parameters
    ----------
    cov    : gp.cov.Covariance object
             The covariance that will be scaled
    scale  : RV
             A random variable that will scale cov differently
             in each block of the kronecker product
    powers : list
             The powers of scale that will multiply cov in each
             block of the kronecker product
    """

    def __init__(self, cov, scale, powers):
        super(ScaledCov, self).__init__(input_dim=cov.input_dim,
                                        active_dims=cov.active_dims)
        self.cov = cov
        self.powers = powers
        self.scales = [scale**(-2*n) for n in powers]
        self.scales_diag = tt.nlinalg.diag(self.scales)

    def diag(self, X):
        """[scales[0] * cov.diag(X), scales[1] * cov.diag(X), ... ].ravel()"""
        X_unique = self.unique_domain(X)
        return tt.outer(self.scales, self.cov.diag(X_unique)).ravel()

    def full(self, X, Xs=None):
        X_unique = self.unique_domain(X)
        Xs_unique = None
        if Xs is not None:
            Xs_unique = self.unique_domain(Xs)
        covfull = self.cov(X_unique, Xs_unique)
        return tt.slinalg.kron(self.scales_diag, covfull)

    def unique_domain(self, X):
        unique_length = len(X)//len(self.powers)
        return X[:unique_length]

This works and I can use the GP API with it (though I have to play with indexing the trace to separate out which curve is which, unless anyone has any better ideas). Conditional works as well, but I have to paste the domain together the appropriate number of times:

# after trace
with test_observable as model:
    # conditional
    Xnew_concat = np.concatenate(tuple(Xnew for n in powers))
    cn = model.gp.conditional("cn", Xnew=Xnew_concat)
    pred_samples = pm.sample_ppc(trace, vars=[cn], samples=50)  # All curves, must disentangle later on.
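
One way to disentangle the curves afterwards is a plain reshape, sketched below in NumPy. It assumes the predictive samples for "cn" come back as an array of shape (n_samples, n_curves * n_new) with the curves laid out block by block, in the same order as the concatenated Xnew. The array here is a stand-in, not real sampler output.

```python
import numpy as np

n_curves, n_new, n_samples = 3, 4, 50

# Stand-in for pred_samples["cn"]; assumed shape (n_samples, n_curves * n_new)
flat = np.arange(n_samples * n_curves * n_new, dtype=float).reshape(n_samples, -1)

# Split the last axis into (curve, point); block order follows the concatenation
per_curve = flat.reshape(n_samples, n_curves, n_new)

# per_curve[:, k, :] holds every sample of curve k
second_curve_first_sample = per_curve[0, 1]
print(per_curve.shape)
```

The same reshape would apply to any other quantity defined on the concatenated domain.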

Unfortunately, finding the trace is actually slower than using a for loop that creates a new gp.Marginal instance for each curve. The Kronecker method for some grid goes at ~10 it/s, while the for loop goes at ~15-20 it/s, all other things held equal. Apparently, the growth in size of the big Kronecker covariance matrix outweighs the benefits from removing the for loop, at least for my test case. Can we take advantage of the fact that many of the Kronecker covariance matrix entries are zero to speed things up?
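
A possible explanation, sketched in NumPy below: when the left factor of the Kronecker product is diagonal, the big matrix is block diagonal, so its log determinant and quadratic form can be computed from one N x N Cholesky factor instead of factorising the full nN x nN matrix. This is a sketch of the linear algebra only, not of how PyMC3 evaluates the model. The function name kron_diag_logp and the test values are made up.

```python
import numpy as np

def kron_diag_logp(Y, d, K):
    """logp of Y.ravel() under N(0, kron(diag(d), K)) using one N x N Cholesky."""
    n, N = Y.shape
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, Y.T)                    # N x n, column j is L^{-1} y_j
    quad = np.sum(alpha**2 / d)                        # sum_j y_j^T K^{-1} y_j / d_j
    logdet = N * np.sum(np.log(d)) + 2 * n * np.sum(np.log(np.diag(L)))
    return -0.5 * (quad + logdet + n * N * np.log(2 * np.pi))

rng = np.random.default_rng(2)
N, n = 7, 4
X = np.linspace(0, 1, N)[:, None]
K = np.exp(-0.5 * (X - X.T)**2 / 0.3**2) + 1e-6 * np.eye(N)
d = 1.2 ** (2 * np.arange(n))                          # diagonal scaling factors
Y = rng.normal(size=(n, N))                            # arbitrary test data

fast = kron_diag_logp(Y, d, K)

# Dense reference: build and factorise the full (n*N) x (n*N) matrix
big = np.kron(np.diag(d), K)
y = Y.ravel()
_, big_logdet = np.linalg.slogdet(big)
dense = -0.5 * (y @ np.linalg.solve(big, y) + big_logdet + y.size * np.log(2 * np.pi))
print(fast, dense)
```

The dense route costs on the order of (nN)^3 operations, while the structured one costs on the order of N^3 plus n N^2, which would be consistent with the for loop beating the dense Kronecker version.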

It is likely that building a large covariance matrix is actually slower. Another approach is to create a coregional kernel as in GPy and GPflow: https://github.com/GPflow/GPflow/blob/69364c2b79e1e6dd0fe7419377a044c0bdedb4f5/gpflow/kernels.py#L664
If I understand correctly, doing it this way does not directly build the complete covariance matrix but just indexes into it?

This looks like something great to implement, but unless I’m missing something obvious is a bit more than I can do with my knowledge of PyMC3 at the moment. If multi-observed GPs are ultimately added, this coregional kernel would definitely extend what is possible in the current version. The “Intrinsic Coregionalization Model” in this paper looks like what I am doing here. I imagine that utilizing properties of Kronecker products for, e.g., inversion, would speed things up considerably.
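
To illustrate the indexing idea behind a coregionalization kernel, here is a NumPy sketch under my own assumptions (it is not the GPflow or GPy implementation). Each data point carries an output index, and the covariance between two points is the input kernel multiplied by the entry of a small matrix B looked up with those indices. The names se_kernel and icm_cov are made up.

```python
import numpy as np

def se_kernel(x1, x2, ls=0.3):
    """Squared-exponential kernel on 1D inputs."""
    return np.exp(-0.5 * (x1[:, None] - x2[None, :])**2 / ls**2)

def icm_cov(x, idx, B, ls=0.3):
    """K[i, j] = B[idx[i], idx[j]] * k(x[i], x[j])."""
    return B[np.ix_(idx, idx)] * se_kernel(x, x, ls)

x_grid = np.linspace(0, 1, 5)
n_out = 3
x = np.tile(x_grid, n_out)                         # inputs repeated once per output
idx = np.repeat(np.arange(n_out), len(x_grid))     # which output each row belongs to
B = np.diag(1.3 ** (2 * np.arange(n_out)))         # diagonal B reproduces the scaled-GP case

K_icm = icm_cov(x, idx, B)
K_kron = np.kron(B, se_kernel(x_grid, x_grid))
print(np.allclose(K_icm, K_kron))
```

On a shared grid this reproduces the Kronecker matrix, so the full covariance is still formed in this sketch. The gain from indexing is flexibility: each output may be observed at different inputs, and B need not be diagonal.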

If some more guidance were provided I would try to help out with this stuff!

In addition to the link @junpenglao posted, also check here. I’m thinking this looks exactly like what you want. I’m interested in helping add this model, clearly it’s useful. Not inverting the full cov matrix will speed things up a lot.

@jordan-melendez I think what you are doing with building conditional in a for loop isn’t necessarily a bad thing, since you only have 5 output variables. Does doing this give you the correct model? You’re right that you can swap in different uncertainties by providing sigma in given.

I think you’re right, messing with shape isn’t what you’re after. shape is eventually passed in here.

It looks like we can add just a Coregion kernel, and be able to use GPflow as a pretty close guide. I’m not sure though about this line. Will need to check if that is a limitation of their implementation or of the Coregionalized GP model.

The for loop in conditional is not a problem for me, and seems to give the correct results. The issue is really with the marginal_likelihood when getting the trace.

In regards to the 1D Coregion kernel of GPflow, it looks like GPy has no restrictions on their version here. But it looks like one would have to implement some sort of mixed noise likelihood or maybe a matrix normal distribution to be able to make use of the Kronecker product, right? Otherwise we’re back to inverting the whole thing every time, I think.

So I have hacked together my own matrix-valued normal distribution, MatNormal, based on the MvNormal for this purpose. With very preliminary testing, it seems to work quickly and correctly! I’ll probably get an MWE up here for testing at some point, but I just wanted to get it out for feedback or corrections.

Here it is.

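import numpy as np
import pymc3 as pm
import theano.tensor as tt
# Import paths as in the PyMC3 source at the time; they may differ in other versions.
from pymc3.distributions.distribution import Continuous, draw_values
from pymc3.distributions.dist_math import bound, Cholesky

class MatNormal(Continuous):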
    R"""
    Matrix-valued normal log-likelihood.

    Distribution for the qxp matrix Y
    Must take two precision-like matrices (cov, chol, or tau):
        left (or column) qxq matrix defines variance within columns
            denoted (lcov, lchol, or ltau) 
        right (or row) pxp matrix defines variance within rows
            denoted (rcov, rchol, or rtau)
    """

    def __init__(self, mu=0, rcov=None, rchol=None, rtau=None,
                 lcov=None, lchol=None, ltau=None, *args, **kwargs):

        self.setup_matrices(rcov, rchol, rtau, lcov, lchol, ltau)

        shape = kwargs.pop('shape', None)
        assert len(shape) == 2, "only 2d tuple inputs work right now: qxp"
        self.shape = shape

        super(MatNormal, self).__init__(shape=shape, *args, **kwargs)

        self.mu = tt.as_tensor_variable(mu)

        self.mean = self.median = self.mode = self.mu

        self.solve_lower = tt.slinalg.Solve(A_structure="lower_triangular")
        self.solve_upper = tt.slinalg.Solve(A_structure="upper_triangular")

    def setup_matrices(self, rcov, rchol, rtau, lcov, lchol, ltau):
        # Step methods and advi do not catch LinAlgErrors at the
        # moment. We work around that by using a cholesky op
        # that returns a nan as first entry instead of raising
        # an error.
        cholesky = Cholesky(nofail=True, lower=True)

        # Right (or row) matrices
        if len([i for i in [rtau, rcov, rchol] if i is not None]) != 1:
            raise ValueError('Incompatible parameterization. '
                             'Specify exactly one of rtau, rcov, '
                             'or rchol.')
        if rcov is not None:
            self.p = rcov.shape[0]  # How many points along vector
            self._rcov_type = 'cov'
            rcov = tt.as_tensor_variable(rcov)
            if rcov.ndim != 2:
                raise ValueError('rcov must be two dimensional.')
            self.rchol_cov = cholesky(rcov)
            self.rcov = rcov
            # self._n = self.rcov.shape[-1]
        elif rtau is not None:
            raise ValueError('rtau not supported at this time')
            self.p = rtau.shape[0]
            self._rcov_type = 'tau'
            rtau = tt.as_tensor_variable(rtau)
            if rtau.ndim != 2:
                raise ValueError('rtau must be two dimensional.')
            self.rchol_tau = cholesky(rtau)
            self.rtau = rtau
            # self._n = self.rtau.shape[-1]
        else:
            self.p = rchol.shape[0]
            self._rcov_type = 'chol'
            if rchol.ndim != 2:
                raise ValueError('rchol must be two dimensional.')
            self.rchol_cov = tt.as_tensor_variable(rchol)
            # self._n = self.rchol_cov.shape[-1]

        # Left (or column) matrices
        if len([i for i in [ltau, lcov, lchol] if i is not None]) != 1:
            raise ValueError('Incompatible parameterization. '
                             'Specify exactly one of ltau, lcov, '
                             'or lchol.')
        if lcov is not None:
            self.q = lcov.shape[0]
            self._lcov_type = 'cov'
            lcov = tt.as_tensor_variable(lcov)
            if lcov.ndim != 2:
                raise ValueError('lcov must be two dimensional.')
            self.lchol_cov = cholesky(lcov)
            self.lcov = lcov
            # self._n = self.lcov.shape[-1]
        elif ltau is not None:
            raise ValueError('ltau not supported at this time')
            self.q = ltau.shape[0]
            self._lcov_type = 'tau'
            ltau = tt.as_tensor_variable(ltau)
            if ltau.ndim != 2:
                raise ValueError('ltau must be two dimensional.')
            self.lchol_tau = cholesky(ltau)
            self.ltau = ltau
            # self._n = self.ltau.shape[-1]
        else:
            self.q = lchol.shape[0]
            self._lcov_type = 'chol'
            if lchol.ndim != 2:
                raise ValueError('lchol must be two dimensional.')
            self.lchol_cov = tt.as_tensor_variable(lchol)
            # self._n = self.lchol_cov.shape[-1]

    def random(self, point=None, size=None):
        if size is None:
            size = list(self.shape)

        mu, rchol, lchol = draw_values([self.mu, self.rchol_cov, self.lchol_cov], point=point)
        standard_normal = np.random.standard_normal(size)

        return mu + lchol @ standard_normal @ rchol.T

    def _trquaddist(self, value):
        """Compute Tr[rcov^-1 (x - mu).T @ lcov^-1 @ (x - mu)] and
        the logdet of rcov and lcov."""
        mu = self.mu

        delta = value - mu

        lchol_cov = self.lchol_cov
        rchol_cov = self.rchol_cov

        rdiag = tt.nlinalg.diag(rchol_cov)
        ldiag = tt.nlinalg.diag(lchol_cov)
        # Check if the covariance matrix is positive definite.
        rok = tt.all(rdiag > 0)
        lok = tt.all(ldiag > 0)
        ok = rok & lok  # elementwise and; Python's `and` would keep only one of the symbolic checks

        # If not, replace the diagonal. We return -inf later, but
        # need to prevent solve_lower from throwing an exception.
        rchol_cov = tt.switch(rok, rchol_cov, 1)
        lchol_cov = tt.switch(lok, lchol_cov, 1)

        # Find exponent piece by piece
        right_quaddist = self.solve_lower(lchol_cov, delta)
        quaddist = tt.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
        quaddist = self.solve_lower(rchol_cov, quaddist)
        quaddist = self.solve_upper(rchol_cov.T, quaddist)
        trquaddist = tt.nlinalg.trace(quaddist)

        half_rlogdet = tt.sum(tt.log(rdiag))  # logdet(M) = 2*Tr(log(L))
        half_llogdet = tt.sum(tt.log(ldiag))

        return trquaddist, half_rlogdet, half_llogdet, ok

    def logp(self, value):
        trquaddist, half_rlogdet, half_llogdet, ok = self._trquaddist(value)
        q = self.q
        p = self.p
        norm = - 0.5 * q * p * pm.floatX(np.log(2 * np.pi))
        return bound(
                norm - 0.5 * trquaddist - q * half_rlogdet - p * half_llogdet,
                ok)
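
For anyone who wants to check the algebra independently of PyMC3, here is a NumPy sketch of the same log density. It relies on the standard identity that a q x p matrix normal with left covariance U and right covariance V has the same density as a multivariate normal on the row-major flattened matrix with covariance kron(U, V). The function name matnormal_logp and the test matrices are made up for illustration.

```python
import numpy as np

def matnormal_logp(Y, U, V):
    """Zero-mean matrix normal logp for Y (q x p); U is q x q (left), V is p x p (right)."""
    q, p = Y.shape
    Lu = np.linalg.cholesky(U)
    Lv = np.linalg.cholesky(V)
    A = np.linalg.solve(Lu, Y)                 # Lu^{-1} Y, q x p
    W = np.linalg.solve(Lv, A.T)               # Lv^{-1} Y^T Lu^{-T}, p x q
    quad = np.sum(W**2)                        # Tr[V^{-1} Y^T U^{-1} Y]
    half_logdet_U = np.sum(np.log(np.diag(Lu)))
    half_logdet_V = np.sum(np.log(np.diag(Lv)))
    return (-0.5 * q * p * np.log(2 * np.pi) - 0.5 * quad
            - p * half_logdet_U - q * half_logdet_V)

rng = np.random.default_rng(3)
q, p = 3, 5
A = rng.normal(size=(q, q))
C = rng.normal(size=(p, p))
U = A @ A.T + q * np.eye(q)                    # random symmetric positive definite matrices
V = C @ C.T + p * np.eye(p)
Y = rng.normal(size=(q, p))

structured = matnormal_logp(Y, U, V)

# Dense reference: multivariate normal on the flattened matrix
big = np.kron(U, V)
y = Y.ravel()
_, big_logdet = np.linalg.slogdet(big)
dense = -0.5 * (y @ np.linalg.solve(big, y) + big_logdet + y.size * np.log(2 * np.pi))
print(structured, dense)
```

The structured version only factorises the q x q and p x p matrices, which is where the speed-up over the full Kronecker covariance comes from.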

Here is a notebook with some basic testing of MatNormal against MvNormal with a for loop. I’ve noticed some biased results in some other (more complicated) files I’ve been testing with, but was unable to reproduce them in this notebook, though I did try. It may be a mistake in those files that I have overlooked.

Anyways, the results are promising!


That looks fantastic!! I think it would be awesome if you could submit this as a PR! I think you can put this into pymc3/distributions/multivariate.py. Yes, everything looks perfect in your notebook. GitHub’s a better place to go over code, so we can go over this more there. I’ve been wanting to add this for a while, but haven’t gotten around to it. With a MatrixNormal distribution, it should be pretty straightforward to add a few more GP implementations, maybe called Marginal2D and Latent2D, including the coregional model. Super awesome!

Done. Thanks everyone for all of your help! Should this be marked as resolved now?

Also, is the addition of features such as Marginal2D or Coregion being discussed anywhere? My research would benefit from those features, and I would like to be kept in the loop and/or would be open to contributing to get the ball rolling on this front.


Thanks for the PR! Also marked your reply above as the solution.
I don’t think there is much discussion elsewhere on Marginal2D or Coregion. Would you like to open a new post?