Multiple (uncertain) function observations of the same Gaussian process

So I have hacked together my own matrix-valued normal distribution, MatNormal, based on MvNormal, for this purpose. With very preliminary testing, it seems to work quickly and correctly. I'll probably post a proper MWE here for testing at some point (a rough sketch of what that might look like is at the bottom of this post), but I wanted to get the code out now for feedback or corrections.
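
For reference, the density I am aiming for is the standard matrix-normal one: for a q-by-p matrix Y with mean M, left/column covariance U (q-by-q) and right/row covariance V (p-by-p),

$$
\log p(Y \mid M, U, V) = -\tfrac{1}{2}\Big(qp\log(2\pi) + q\log\lvert V\rvert + p\log\lvert U\rvert + \operatorname{tr}\!\big[V^{-1}(Y-M)^{\top}U^{-1}(Y-M)\big]\Big),
$$

which is what logp below computes from the Cholesky factors of U (lcov) and V (rcov).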

Here it is.

# Imports assumed for a PyMC3 + Theano environment; in my install the
# nofail Cholesky op lives in pymc3.distributions.dist_math.
import numpy as np
import theano.tensor as tt
import pymc3 as pm
from pymc3.distributions.distribution import Continuous, draw_values
from pymc3.distributions.dist_math import Cholesky, bound


class MatNormal(Continuous):
    R"""
    Matrix-valued normal log-likelihood.

    Distribution for the qxp matrix Y.
    Takes exactly one scale matrix for each side, given as a
    covariance (cov), Cholesky factor (chol), or precision (tau):
        left (or column) qxq matrix defines the covariance within columns,
            denoted (lcov, lchol, or ltau)
        right (or row) pxp matrix defines the covariance within rows,
            denoted (rcov, rchol, or rtau)
    """

    def __init__(self, mu=0, rcov=None, rchol=None, rtau=None,
                 lcov=None, lchol=None, ltau=None, *args, **kwargs):

        self.setup_matrices(rcov, rchol, rtau, lcov, lchol, ltau)

        shape = kwargs.pop('shape', None)
        if shape is None or len(shape) != 2:
            raise ValueError("shape must be a 2d tuple (q, p); "
                             "only 2d inputs work right now")
        self.shape = shape

        super(MatNormal, self).__init__(shape=shape, *args, **kwargs)

        self.mu = tt.as_tensor_variable(mu)

        self.mean = self.median = self.mode = self.mu

        self.solve_lower = tt.slinalg.Solve(A_structure="lower_triangular")
        self.solve_upper = tt.slinalg.Solve(A_structure="upper_triangular")

    def setup_matrices(self, rcov, rchol, rtau, lcov, lchol, ltau):
        # Step methods and advi do not catch LinAlgErrors at the
        # moment. We work around that by using a cholesky op
        # that returns a nan as first entry instead of raising
        # an error.
        cholesky = Cholesky(nofail=True, lower=True)

        # Right (or row) matrices
        if len([i for i in [rtau, rcov, rchol] if i is not None]) != 1:
            raise ValueError('Incompatible parameterization. '
                             'Specify exactly one of rtau, rcov, '
                             'or rchol.')
        if rcov is not None:
            self.p = rcov.shape[0]  # How many points along vector
            self._rcov_type = 'cov'
            rcov = tt.as_tensor_variable(rcov)
            if rcov.ndim != 2:
                raise ValueError('rcov must be two dimensional.')
            self.rchol_cov = cholesky(rcov)
            self.rcov = rcov
            # self._n = self.rcov.shape[-1]
        elif rtau is not None:
            raise ValueError('rtau not supported at this time')
            # unreachable placeholder for eventual tau support:
            self.p = rtau.shape[0]
            self._rcov_type = 'tau'
            rtau = tt.as_tensor_variable(rtau)
            if rtau.ndim != 2:
                raise ValueError('rtau must be two dimensional.')
            self.rchol_tau = cholesky(rtau)
            self.rtau = rtau
            # self._n = self.rtau.shape[-1]
        else:
            self.p = rchol.shape[0]
            self._rcov_type = 'chol'
            if rchol.ndim != 2:
                raise ValueError('rchol must be two dimensional.')
            self.rchol_cov = tt.as_tensor_variable(rchol)
            # self._n = self.rchol_cov.shape[-1]

        # Left (or column) matrices
        if len([i for i in [ltau, lcov, lchol] if i is not None]) != 1:
            raise ValueError('Incompatible parameterization. '
                             'Specify exactly one of ltau, lcov, '
                             'or lchol.')
        if lcov is not None:
            self.q = lcov.shape[0]
            self._lcov_type = 'cov'
            lcov = tt.as_tensor_variable(lcov)
            if lcov.ndim != 2:
                raise ValueError('lcov must be two dimensional.')
            self.lchol_cov = cholesky(lcov)
            self.lcov = lcov
            # self._n = self.lcov.shape[-1]
        elif ltau is not None:
            raise ValueError('ltau not supported at this time')
            # unreachable placeholder for eventual tau support:
            self.q = ltau.shape[0]
            self._lcov_type = 'tau'
            ltau = tt.as_tensor_variable(ltau)
            if ltau.ndim != 2:
                raise ValueError('ltau must be two dimensional.')
            self.lchol_tau = cholesky(ltau)
            self.ltau = ltau
            # self._n = self.ltau.shape[-1]
        else:
            self.q = lchol.shape[0]
            self._lcov_type = 'chol'
            if lchol.ndim != 2:
                raise ValueError('lchol must be two dimensional.')
            self.lchol_cov = tt.as_tensor_variable(lchol)
            # self._n = self.lchol_cov.shape[-1]

    def random(self, point=None, size=None):
        if size is None:
            size = list(self.shape)

        mu, rchol, lchol = draw_values([self.mu, self.rchol_cov, self.lchol_cov], point=point)
        standard_normal = np.random.standard_normal(size)
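        # If Z has iid standard-normal entries, mu + lchol @ Z @ rchol.T
        # has left/column covariance lchol @ lchol.T and right/row
        # covariance rchol @ rchol.T, i.e. vec(Y) ~ N(vec(mu), rcov kron lcov).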

        return mu + lchol @ standard_normal @ rchol.T

    def _trquaddist(self, value):
        """Compute Tr[rcov^-1 (x - mu).T @ lcov^-1 @ (x - mu)] and
        the logdet of rcov and lcov."""
        mu = self.mu

        delta = value - mu

        lchol_cov = self.lchol_cov
        rchol_cov = self.rchol_cov

        rdiag = tt.nlinalg.diag(rchol_cov)
        ldiag = tt.nlinalg.diag(lchol_cov)
        # Check that both covariance matrices are positive definite
        # (the nofail cholesky op returns nan instead of raising, and
        # nan > 0 evaluates to False).
        rok = tt.all(rdiag > 0)
        lok = tt.all(ldiag > 0)
        # Python's `and` does not work on symbolic variables, so combine
        # the two checks with a symbolic and.
        ok = tt.and_(rok, lok)

        # If not, replace the diagonal. We return -inf later, but
        # need to prevent solve_lower from throwing an exception.
        rchol_cov = tt.switch(rok, rchol_cov, 1)
        lchol_cov = tt.switch(lok, lchol_cov, 1)

        # Find exponent piece by piece
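        # right_quaddist = L_l^-1 (x - mu)
        # right_quaddist.T @ right_quaddist = (x - mu).T lcov^-1 (x - mu)
        # the two triangular solves against rchol then left-multiply by
        # rcov^-1, and the trace of the result is the exponent term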
        right_quaddist = self.solve_lower(lchol_cov, delta)
        quaddist = tt.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
        quaddist = self.solve_lower(rchol_cov, quaddist)
        quaddist = self.solve_upper(rchol_cov.T, quaddist)
        trquaddist = tt.nlinalg.trace(quaddist)

        half_rlogdet = tt.sum(tt.log(rdiag))  # logdet(M) = 2*sum(log(diag(L)))
        half_llogdet = tt.sum(tt.log(ldiag))

        return trquaddist, half_rlogdet, half_llogdet, ok

    def logp(self, value):
        trquaddist, half_rlogdet, half_llogdet, ok = self._trquaddist(value)
        q = self.q
        p = self.p
        norm = - 0.5 * q * p * pm.floatX(np.log(2 * np.pi))
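        # log p = -0.5 * (q*p*log(2*pi) + q*logdet(rcov)
        #                 + p*logdet(lcov) + trquaddist)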
        return bound(
                norm - 0.5 * trquaddist - q * half_rlogdet - p * half_llogdet,
                ok)
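
And here is the kind of MWE I have in mind for the multiple-observations-of-one-GP setting. It is an untested sketch: the sqexp helper, the priors, the sizes, and the random stand-in data are all placeholders, so take it as an illustration of the intended API rather than verified code.

import numpy as np
import pymc3 as pm
import theano.tensor as tt

q, p = 3, 30                          # 3 observed curves, 30 input points
X = np.linspace(0, 1, p)[:, None]
Y = np.random.randn(q, p)             # stand-in for the observed q x p data


def sqexp(X, eta, ls):
    # squared-exponential kernel written out by hand so the sketch does
    # not depend on a particular pm.gp.cov API version
    d2 = tt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return eta ** 2 * tt.exp(-0.5 * d2 / ls ** 2)


with pm.Model() as model:
    # right/row covariance: GP kernel over the p input locations,
    # plus iid observation noise on the diagonal
    ls = pm.Gamma('ls', alpha=2, beta=2)
    eta = pm.HalfCauchy('eta', beta=1)
    sigma = pm.HalfCauchy('sigma', beta=1)
    rcov = sqexp(X, eta, ls) + sigma ** 2 * tt.eye(p)

    # left/column covariance: covariance among the q observed curves;
    # the simplest placeholder is a shared scale times the identity
    tau = pm.HalfCauchy('tau', beta=1)
    lcov = tau ** 2 * tt.eye(q)

    y = MatNormal('y', mu=0, rcov=rcov, lcov=lcov,
                  shape=(q, p), observed=Y)

    trace = pm.sample(1000)

A natural sanity check would be to compare logp against an ordinary MvNormal on the flattened data with the corresponding Kronecker-structured covariance, since the two should agree.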