Dear all,
I'm quite new to PyMC and would like to use it for image processing, so I need it to work on large-scale problems. As an initial test I wrote a small program that samples a large number of independent variables. However, with the NUTS sampler the performance becomes very poor as the number of variables increases, and the degradation seems to be quite sudden.
My model consists of a prior mu_i ~ U[0,1] and observed variables x_ij ~ N(mu_i, sigma=0.2). The task is to estimate mu_i given a few observations x_ij for each site i.
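Because the sites are independent and sigma is known, the exact posterior is available as a reference for checking the sampler: with a uniform prior on [0,1], mu_i given its n observations is a normal with mean x̄_i and standard deviation sigma/√n, truncated to [0,1]. A minimal SciPy sketch; `exact_posterior` is a hypothetical helper, and X is assumed to use the (n_obs, n_variables) layout of the posted code:

```python
import numpy as np
from scipy import stats


def exact_posterior(X, sd=0.2, lower=0.0, upper=1.0):
    """Exact posterior of every mu_i under a U[lower, upper] prior.

    X has shape (n_obs, n_variables); sd is the known observation noise.
    Returns a frozen truncated normal with one component per site.
    """
    n_obs = X.shape[0]
    loc = X.mean(axis=0)              # sample mean x̄_i of each site
    scale = sd / np.sqrt(n_obs)       # sigma / sqrt(n)
    # truncnorm takes the truncation bounds in standardised units
    a = (lower - loc) / scale
    b = (upper - loc) / scale
    return stats.truncnorm(a, b, loc=loc, scale=scale)
```

Comparing `exact_posterior(X).mean()` with the per-site means of the `mu` trace is one way to confirm that the slow runs still produce correct samples.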
If I use i=1…4000 and j=1…10, I get about 100 samples per second. But once I go up to about 7000-8000 sites, the performance drops sharply: the NUTS sampler gives only about 1-4 samples per second for long periods of time.
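A back-of-the-envelope check on these numbers, assuming that the cost of one gradient evaluation scales linearly with the number of sites and that the number of leapfrog steps per sample stays fixed (both are assumptions, not measurements):

```python
import math


def expected_rate(rate_ref, n_ref, n_new):
    # Assumption: samples/s is inversely proportional to the per-gradient
    # cost, which grows linearly with the number of sites; tree size fixed.
    return rate_ref * n_ref / n_new


baseline = expected_rate(100.0, 4000, 8000)        # 50.0 samples/s
unexplained = [baseline / r for r in (1.0, 4.0)]   # [50.0, 12.5]
# NUTS doubles the number of leapfrog steps with each extra tree level
extra_depth = [math.log2(f) for f in unexplained]  # roughly [5.6, 3.6]
```

Under those assumptions, going from 4000 to 8000 sites explains only a factor of 2. The remaining factor of 12.5-50 would be consistent with roughly 4-6 extra tree doublings per sample, which the per-draw NUTS statistics in the trace (for example `trace.get_sampler_stats("tree_size")`) could confirm or rule out.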
Is this a limitation of the NUTS algorithm, a bug, or are there any tricks I can use to speed it up?
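For context on what NUTS computes at each leapfrog step: it evaluates the gradient of the log posterior on the unconstrained scale. Below is a hand-written NumPy sketch for this model, assuming the log-odds (interval) transform that PyMC3 applies to `Uniform(0, 1)` by default. `logp_and_grad` is a hypothetical name; this is not PyMC3's internal code:

```python
import numpy as np


def logp_and_grad(y, X, sd=0.2):
    """Unnormalised log posterior and its gradient w.r.t. y = logit(mu).

    Assumes mu = sigmoid(y), i.e. the log-odds transform of U(0, 1).
    X has shape (n_obs, n_variables); y has shape (n_variables,).
    """
    mu = 1.0 / (1.0 + np.exp(-y))
    resid = (X - mu) / sd                  # one pass over all n_obs * n_variables values
    loglik = -0.5 * np.sum(resid ** 2)
    # log-Jacobian of the transform: log|d mu / d y| = log(mu) + log(1 - mu)
    logjac = np.sum(np.log(mu) + np.log1p(-mu))
    # chain rule: d loglik/d mu_i = sum_j resid_ji / sd, and d mu/d y = mu (1 - mu);
    # the Jacobian term contributes 1 - 2 mu
    grad = mu * (1.0 - mu) * resid.sum(axis=0) / sd + (1.0 - 2.0 * mu)
    return loglik + logjac, grad
```

Each gradient component depends only on column i of X, so one gradient costs a single pass over the data, and the posterior factorizes across sites.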
My machine: Core i7-5820K, 64GB of RAM, Win8.1, PyMC3 3.6, Python 3.6, PyCharm
The code:
```python
import pymc3 as pm
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np


def large_scale_mcmc(n_variables, n_obs, n_samples, n_chains=4, sd=0.2):
    # Generate the observations: Mu[i] ~ U[0, 1], X[j, i] ~ N(Mu[i], sd)
    Mu = stats.uniform.rvs(size=n_variables)
    X = np.zeros((n_obs, n_variables))
    for i in range(n_variables):
        X[:, i] = sd * stats.norm.rvs(size=n_obs) + Mu[i]

    # Set up the model and sample
    with pm.Model() as model:
        mu = pm.Uniform("mu", 0, 1, shape=n_variables)
        # X has shape (n_obs, n_variables), so mu broadcasts along the last axis
        pm.Normal("x", mu=mu, sd=sd, shape=(n_obs, n_variables), observed=X)
        step = pm.NUTS()
        trace_mu = pm.sample(n_samples, chains=n_chains, step=step)
    return trace_mu, Mu


if __name__ == '__main__':
    n_var = 70  # the slowdown described above appears at around 7000-8000
    n_observations = 10
    n_samp = 100
    trace, mu_true = large_scale_mcmc(n_var, n_observations, n_samp, n_chains=4, sd=0.2)
```
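The per-column loop that builds X can also be written with NumPy broadcasting, which avoids thousands of separate `stats.norm.rvs` calls at these problem sizes. A minimal sketch; `generate_observations` and its `seed` argument are hypothetical, and it draws from NumPy's `RandomState` instead of `scipy.stats`, so the random numbers differ from the loop's:

```python
import numpy as np


def generate_observations(n_variables, n_obs, sd=0.2, seed=None):
    rng = np.random.RandomState(seed)
    # True site means, one per variable: Mu[i] ~ U[0, 1]
    Mu = rng.uniform(0.0, 1.0, size=n_variables)
    # Mu (shape (n_variables,)) is broadcast onto every row of the
    # (n_obs, n_variables) noise matrix, matching the X[:, i] layout of the loop
    X = Mu + sd * rng.standard_normal(size=(n_obs, n_variables))
    return X, Mu
```

This only shortens data generation; it does not change the NUTS sampling time.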
Best,
Anders