I'm quite new to PyMC and would like to use it for image processing, so I need it to work on large-scale problems. As an initial test I wrote a small program that samples a large number of independent variables. However, with the NUTS algorithm I get very poor performance as the number of variables increases, and the degradation seems to be quite sudden.
My model consists of a prior mu_i ~ U[0,1] and observed variables x_ij ~ N(mu_i, sigma=0.2). The task is to estimate mu_i given a few observations x_ij for each site i.
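For reference, here is a minimal NumPy-only sketch of the same generative model, independent of PyMC. Because the prior is flat on [0,1] and the likelihood is Gaussian with known sigma, the posterior of each mu_i is a normal centred on the site mean of x_ij with standard deviation sigma/sqrt(J), truncated to [0,1], where J is the number of observations per site. The function names `simulate_sites` and `untruncated_posterior` are hypothetical helpers introduced only for this illustration.

```python
import numpy as np


def simulate_sites(n_sites, n_obs, sd=0.2, seed=0):
    """Draw mu_i ~ U[0, 1] and x_ij ~ N(mu_i, sd) for every site i."""
    rng = np.random.RandomState(seed)
    mu_true = rng.uniform(0.0, 1.0, size=n_sites)
    # Rows index observations j, columns index sites i.
    x = mu_true + sd * rng.standard_normal(size=(n_obs, n_sites))
    return mu_true, x


def untruncated_posterior(x, sd=0.2):
    """Centre and scale of the normal whose truncation to [0, 1]
    is the posterior of mu_i under the flat prior."""
    n_obs = x.shape[0]
    return x.mean(axis=0), sd / np.sqrt(n_obs)


if __name__ == "__main__":
    mu_true, x = simulate_sites(n_sites=4000, n_obs=10)
    centre, scale = untruncated_posterior(x)
    print(centre[:5], scale)
```

For sites whose posterior mass sits well inside [0,1], `centre` should be close to the per-site posterior mean from the sampler (for example `trace["mu"].mean(axis=0)`), which gives a cheap sanity check on the MCMC output.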
If I use i=1…4000 and j=1…10, I get about 100 samples per second. But once I get up to about 7000-8000 variables, the performance drops sharply: the NUTS algorithm gives only about 1-4 samples per second for long periods of time.
Is this a limitation of the NUTS algorithm, an issue, or are there any tricks I can use to speed it up?
My machine: Core i7-5820K, 64 GB of RAM, Windows 8.1, PyMC3 3.6, Python 3.6, PyCharm
```python
import pymc3 as pm
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np


def large_scale_mcmc(n_variables, n_obs, n_samples, n_chains=4, sd=0.2):
    # Generate the observations: one column per site, one row per observation
    Mu = stats.uniform.rvs(size=n_variables)
    X = np.zeros((n_obs, n_variables))
    for i in range(n_variables):
        X[:, i] = sd*stats.norm.rvs(size=n_obs) + Mu[i]

    # Setup the model and sample
    with pm.Model() as model:
        mu = pm.Uniform("mu", 0, 1, shape=n_variables)
        # shape matches X, i.e. (n_obs, n_variables); mu broadcasts across the rows
        pm.Normal("x", mu=mu, sd=sd, shape=X.shape, observed=X)
        step = pm.NUTS()
        trace_mu = pm.sample(n_samples, chains=n_chains, step=step)
    return trace_mu, Mu


if __name__ == '__main__':
    n_var = 70
    n_observations = 10
    n_samp = 100
    trace, mu_true = large_scale_mcmc(n_var, n_observations, n_samp, n_chains=4, sd=0.2)
```