Speeding up MatrixNormal for missing data imputation (or more generally)

I have a dataset with quite a bit of structure across both rows (samples) and columns (biological measurements), and I’d like to use a MatrixNormal to model that structure and perform Bayesian imputation of the missing values.

A very similar (public) dataset is the Arcene mass-spec data:

import numpy as np
import pandas as pd

np.random.seed(12345)

# using the arcene dataset as a substrate
# http://archive.ics.uci.edu/ml/datasets/Arcene
Xraw = pd.read_csv('arcene_train.data.gz', sep=' ', header=None).T
Xraw = Xraw.iloc[:10000, :]  # drop the all-NaN trailing row left by the trailing separator
Xraw.index = ['peptide_%d' % i for i in Xraw.index]  # after the transpose, rows are measurements
Xraw.columns = ['sample_%d' % i for i in Xraw.columns]

desired_prop_na = 0.025
n_na = int(desired_prop_na * np.prod(Xraw.shape)) + 1
# positions are drawn with replacement, so repeats push the realized
# missing fraction slightly below desired_prop_na
na_idx = (np.random.choice(Xraw.shape[0], n_na, replace=True),
          np.random.choice(Xraw.shape[1], n_na, replace=True))

X = Xraw.copy().values
X[na_idx] = np.nan
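
(Quick sanity check that the mask landed where expected:)

print('missing fraction: %.4f' % np.isnan(X).mean())  # ~0.025, slightly less due to repeats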

My dataset is about the same size (~100 samples, ~10,000 measurements), so the empirical correlation structure of the measurements is necessarily low-rank; I’d therefore like to use a diagonal-plus-low-rank approximation for each covariance:

\Sigma \approx \Psi + LL^T

with \Psi an n \times n (resp. m \times m) diagonal matrix, and L an n \times k (resp. m \times k) low-rank factor. That gives roughly (m+n)(k+2) parameters (two diagonals, two packed factors, and the row and column means) against mn observations; so plenty of degrees of freedom in this model for, say, k=10.
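
For concreteness, a back-of-the-envelope count at the sizes above (the exact figure also subtracts k(k-1) for the triangular packing of the two factors):

n, m, k = 10000, 100, 10
n_params = (n + m) * (k + 2) - k * (k - 1)  # diagonals + packed factors + row/col means
print('%d parameters vs %d observations' % (n_params, n * m))  # 121110 vs 1000000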

I’m running into the issue that inference here is insanely slow, at about a minute per ADVI iteration:

import pymc3 as pm
import theano.tensor as tt

k_row = k_col = 10  # low-rank dimension (the k=10 discussed above)

with pm.Model() as matnorm_dpl:
    row_n, col_n = X.shape
    row_cov = diag_plus_lowrank(row_n, k_row, 'row_cov')
    col_cov = diag_plus_lowrank(col_n, k_col, 'col_cov')
    row_means = pm.Normal('row_means', 0., 5., shape=(row_n, 1))
    col_means = pm.Normal('col_means', 0., 5., shape=(1, col_n))
    # broadcast the row and column means up to a full matrix of cell means
    mat_mean = pm.Deterministic('M', row_means + tt.zeros((row_n, col_n)) + col_means)
    # masked (missing) entries become latent variables that PyMC3 imputes
    data = np.ma.masked_invalid(np.log(1 + X))
    lik = pm.MatrixNormal('Xmodel', mat_mean,
                          rowcov=row_cov,
                          colcov=col_cov, shape=data.shape,
                          observed=data)

    ### 60s/it = 2 months to fit
    samples = pm.ADVI().fit(100000).sample(2000)
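
(A rough way to confirm where the time goes; as far as I can tell, each logp evaluation has to factor the dense 10000 x 10000 row covariance, and every ADVI step needs at least one logp/grad call:)

from timeit import default_timer as timer

t0 = timer()
matnorm_dpl.logp(matnorm_dpl.test_point)  # one log-density evaluation
print('single logp call: %.1f s' % (timer() - t0))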

(helper functions for completeness)

def expand_packed_block_triangular(n, k, packed, diag=None, mtype='theano'):
    # like pm.expand_packed_triangular, but for a tall n x k block (n >= k)
    assert mtype in {'theano', 'numpy'}
    assert n >= k
    def set_(M, i_, v_):
        if mtype == 'theano':
            return tt.set_subtensor(M[i_], v_)
        M[i_] = v_
        return M
    out = (tt.zeros((n, k), dtype=float) if mtype == 'theano'
           else np.zeros((n, k), dtype=float))
    if diag is None:
        idxs = np.tril_indices(n, m=k)
        out = set_(out, idxs, packed)
    else:
        idxs = np.tril_indices(n, k=-1, m=k)
        out = set_(out, idxs, packed)
        idxs = (np.arange(k), np.arange(k))
        out = set_(out, idxs, diag)
    return out
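
(To illustrate the packing convention with toy numbers, using the numpy branch:)

L = expand_packed_block_triangular(3, 2, np.array([1., 2., 3.]),
                                   diag=np.array([4., 5.]), mtype='numpy')
# array([[4., 0.],
#        [1., 5.],
#        [2., 3.]])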


def diag_plus_lowrank(n, k, prefix='dpl', s_diag=5., s_lr=1.):
    """\
    Create a diagonal-plus-low-rank parameterization Psi + L L^T of an
    n x n covariance matrix, with L an n x k block lower-triangular factor.
    """
    psi = pm.HalfNormal('%s_diag' % prefix, s_diag, shape=(n,))
    d = pm.HalfNormal('%s_lr_diag' % prefix, s_lr, shape=(k,))
    # strictly-lower-triangular entries of the n x k block
    n_od = int(k*n - k*(k-1)/2 - k)
    od = pm.Normal('%s_lr_od' % prefix, 0., s_lr, shape=(n_od,))
    L = pm.Deterministic('%s_lr_L' % prefix, expand_packed_block_triangular(n, k, od, d))
    P = pm.Deterministic('%s_Psi' % prefix, tt.diag(psi))
    return pm.Deterministic('%s_cov' % prefix, P + tt.dot(L, tt.transpose(L)))
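
(And a quick consistency check that the packed-entry count matches the indexing in the helper above:)

# n=5, k=2: the strictly-lower block triangle has 2*5 - 1 - 2 = 7 entries
assert int(2*5 - 2*(2-1)/2 - 2) == len(np.tril_indices(5, k=-1, m=2)[0]) == 7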

By contrast, the dumb iterated-SVD approach fits in a few seconds:

import scipy.sparse.linalg

def svd_impute(X, k=75, iters=200):
    na_idx = np.where(np.isnan(X))
    X[na_idx] = np.random.normal(size=len(na_idx[0]))  # random initial fill
    for j in range(iters):
        # rank-k reconstruction via truncated SVD
        U, s, Vt = scipy.sparse.linalg.svds(X, k=k)
        Xapp = np.linalg.multi_dot([U, np.diag(s), Vt])
        delta = np.sum((X[na_idx] - Xapp[na_idx])**2)
        X[na_idx] = Xapp[na_idx]
        print('iter %d: %.2e' % (j, delta))
        if delta < 1e-3:
            break
    return X
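
(Usage, on the same log scale the MatrixNormal sees; X_filled is just an illustrative name:)

X_filled = svd_impute(np.log(1 + X).copy(), k=75)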

A ~600,000x slowdown seems like a high price to pay for obtaining posteriors. Is there a good way to speed this up?