Issues with NUTS performance for large-scale problems

Dear all,

I'm quite new to PyMC3 and would like to use it for image processing, so I need it to handle large-scale problems. As an initial test, I wrote a small program that samples a large number of independent variables. However, the performance of the NUTS sampler becomes very poor as the number of variables increases, and the degradation seems quite sudden.

My model has a prior mu_i ~ U[0,1] and observed variables x_ij ~ N(mu_i, sigma=0.2). The task is to estimate each mu_i given a few observations x_ij at site i.
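Because the uniform prior only truncates the Gaussian likelihood, each mu_i has a closed-form posterior: a normal with mean x̄_i (the mean of the x_ij at site i) and standard deviation sigma/sqrt(n_obs), truncated to [0, 1]. That gives an exact reference to check the sampler's output against at any problem size. A minimal sketch, assuming the (n_obs, n_variables) data layout used in the code below and SciPy's `truncnorm`; `exact_posterior` is a hypothetical helper name and the data are simulated.

```python
import numpy as np
from scipy import stats


def exact_posterior(X, sd=0.2):
    """Exact posterior of mu_i for mu_i ~ U[0,1], x_ij ~ N(mu_i, sd).

    X has shape (n_obs, n_variables). The likelihood for site i is a normal
    in mu_i with mean xbar_i and sd sd/sqrt(n_obs); the uniform prior only
    truncates it to [0, 1].
    """
    n_obs = X.shape[0]
    loc = X.mean(axis=0)                              # per-site sample mean xbar_i
    scale = sd / np.sqrt(n_obs)                       # posterior sd before truncation
    a, b = (0.0 - loc) / scale, (1.0 - loc) / scale   # bounds in standard units
    return stats.truncnorm(a, b, loc=loc, scale=scale)  # vectorized over sites


# Simulated data with the same layout as in the question (hypothetical sizes).
rng = np.random.RandomState(0)
mu_true = rng.uniform(size=5000)
X = mu_true + 0.2 * rng.randn(10, 5000)   # shape (n_obs, n_variables)
post = exact_posterior(X)
print(post.mean()[:3], post.std()[:3])    # exact posterior means and sds of the first sites
```

Comparing these exact means and standard deviations with the NUTS trace separates "the sampler is slow" from "the sampler is wrong".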

With i = 1…4000 and j = 1…10 I get about 100 samples per second, but as the number of sites approaches 7000-8000 the performance drops sharply: NUTS then produces only about 1-4 samples per second for long stretches.

Is this a limitation of the NUTS algorithm, a bug, or are there any tricks I can use to speed it up?
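One way to check whether the slowdown comes from NUTS itself: each leapfrog step costs one gradient evaluation, so the time per draw is roughly the number of leapfrog steps (the tree size) times the cost of one gradient. If the tree size jumps to the depth cap as the number of variables grows, the sampler is likely struggling with step size and scaling rather than just paying for a bigger gradient. A minimal NumPy sketch; `summarize_nuts_effort` is a hypothetical helper, the default cap of 10 is assumed to match PyMC3's NUTS `max_treedepth`, and the inputs are assumed to be the PyMC3 sampler statistics `tree_size` and `depth`.

```python
import numpy as np


def summarize_nuts_effort(tree_size, depth, max_treedepth=10):
    """Summarize how much work NUTS did per draw.

    tree_size: leapfrog steps per draw (one gradient evaluation each).
    depth: tree depth reached per draw.
    max_treedepth: the sampler's depth cap (assumed PyMC3 default of 10).
    """
    tree_size = np.asarray(tree_size, dtype=float)
    depth = np.asarray(depth)
    return {
        "mean_leapfrog_steps": float(tree_size.mean()),
        "max_leapfrog_steps": int(tree_size.max()),
        "frac_at_max_depth": float(np.mean(depth >= max_treedepth)),
    }


# Made-up numbers for illustration; with a PyMC3 MultiTrace `trace` the inputs
# would come from trace.get_sampler_stats("tree_size") and
# trace.get_sampler_stats("depth").
example = summarize_nuts_effort(tree_size=[7, 15, 1023, 1023], depth=[3, 4, 10, 10])
print(example)  # → {'mean_leapfrog_steps': 517.0, 'max_leapfrog_steps': 1023, 'frac_at_max_depth': 0.5}
```

Running this on the traces for 4000 and 8000 sites would show whether the drop to 1-4 samples per second coincides with a large fraction of draws hitting the depth cap.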

My machine: Core i7-5820K, 64 GB of RAM, Windows 8.1, PyMC3 3.6, Python 3.6, PyCharm

The code:

import pymc3 as pm
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np


def large_scale_mcmc(n_variables, n_obs, n_samples, n_chains=4, sd=0.2):
    # Generate the observations
    Mu = stats.uniform.rvs(size=n_variables)
    X = np.zeros((n_obs, n_variables))
    for i in range(n_variables):
        X[:, i] = sd*stats.norm.rvs(size=n_obs) + Mu[i]

    # Setup the model and sample
    with pm.Model() as model:
        mu = pm.Uniform("mu", 0, 1, shape=n_variables)
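        # X has shape (n_obs, n_variables); the shape is taken from the observed
        # data, and mu (length n_variables) broadcasts across the n_obs rows
        pm.Normal("x", mu=mu, sd=sd, observed=X)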
        step = pm.NUTS()
        trace_mu = pm.sample(n_samples, chains=n_chains, step=step)

    return trace_mu, Mu

if __name__ == '__main__':
    n_var = 70  # the slowdown described above appears at roughly 7000-8000 variables
    n_observations = 10
    n_samp = 100
    trace, mu_true = large_scale_mcmc(n_var, n_observations, n_samp, n_chains=4, sd=0.2)

Best,
Anders

I think a good thing to try is setting a more meaningful mass matrix for NUTS. For example, you can try the custom mass matrix adaptation in Exoplanet (a toolkit for modeling transit and/or radial velocity observations of exoplanets using PyMC3), use an initialization such as init="advi+adapt_diag", or run once on a smaller dataset and use the resulting posterior covariance matrix as the mass matrix to initialize the NUTS sampler.
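The last suggestion amounts to giving NUTS a covariance (inverse mass matrix) estimated from a pilot run. For this model the sites are independent under both the prior and the likelihood, so the posterior covariance is diagonal and only per-coordinate variances are needed. A minimal NumPy sketch of that estimation step; `diag_scaling_from_pilot` and its `floor` value are hypothetical choices. The PyMC3 calls in the comments (`pm.sample(..., init="advi+adapt_diag")`, `pm.NUTS(scaling=..., is_cov=True)`) and the transformed-variable name `mu_interval__` are assumptions about PyMC3 3.6 to check against its documentation. The variances should be taken in the sampler's unconstrained space, and the pilot draws must have the same dimension as the full model's free variables.

```python
import numpy as np


def diag_scaling_from_pilot(pilot_draws, floor=1e-8):
    """Per-coordinate posterior variance from pilot draws.

    pilot_draws: array of shape (n_draws, n_dim) in the sampler's unconstrained
    space. The result is meant as a diagonal covariance (inverse mass matrix)
    for NUTS; `floor` (a hypothetical default) keeps every entry strictly positive.
    """
    draws = np.asarray(pilot_draws, dtype=float)
    return np.maximum(draws.var(axis=0, ddof=1), floor)


# Synthetic pilot draws with known per-coordinate scales (not real sampler output).
rng = np.random.RandomState(1)
fake_pilot = rng.randn(500, 3) * np.array([0.06, 0.5, 2.0])
print(diag_scaling_from_pilot(fake_pilot))  # should be close to 0.06**2, 0.5**2, 2.0**2

# Hypothetical usage with PyMC3 3.6 (not executed here; names are assumptions):
# with model:
#     pilot = pm.sample(200, tune=500, chains=1, init="advi+adapt_diag")
#     var = diag_scaling_from_pilot(pilot["mu_interval__"])
#     step = pm.NUTS(scaling=var, is_cov=True)
#     trace = pm.sample(n_samples, chains=n_chains, step=step)
```

With a well-scaled mass matrix, NUTS can usually take larger steps relative to the posterior width, which reduces the number of leapfrog steps (and gradient evaluations) per draw.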