Is it possible to speed up PyMC sampling?


I’m running a model on a fairly large dataset. I’ve just taken a job with unlimited use of Google Cloud Platform and thought that if I chose a higher-performing CPU and more memory, sampling would speed up noticeably.

I haven’t seen that. The data I’m sampling has around 500,000 observations. It’s a time series based on daily data for five years with multiple items forecasted per year.

I’m currently using a 56-core, 112 GB RAM setup. No GPU, as I’ve actually never run PyMC with NumPyro/JAX on a GPU. Will a GPU help?

Do any of you have suggestions to try?


For reference, here is my model and the versions of the programs:

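import pymc as pm
import pymc.sampling_jax

# coords, t, s, n_changepoints and train_y are defined earlier in the notebook.
# A (the n_obs x n_changepoints changepoint indicator matrix) is an assumed name:
# the dot-product arguments were lost when this post was copied.
with pm.Model(coords = coords) as model:
    # item_idx = pm.Data('item_idx', items, dims = "obs_id", mutable = False)
    k = pm.Normal('k', 0, 1)
    m = pm.Normal('m', 0, 5)
    delta = pm.Laplace('delta', 0, 0.1, shape = n_changepoints)
    growth = k + pm.math.dot(A, delta)
    offset = m + pm.math.dot(A, -s * delta)
    trend = growth * t + offset
#     beta_weekly = pm.Normal('beta_weekly_seasonality', 0, 1, shape = weekly_n_components * 2)
#     seasonality_weekly = ...  # dot(<seasonal features, p = 7>, beta_weekly); first argument lost in the post
#     beta_monthly = pm.Normal('beta_monthly_seasonality', 0, 1, shape = monthly_n_components * 2)
#     seasonality_monthly = ...  # dot(<seasonal features, p = 30.5>, beta_monthly); first argument lost
#     beta_yearly = pm.Normal('beta_yearly_seasonality', 0, 1, shape = yearly_n_components * 2)
#     seasonality_yearly = ...  # dot(<seasonal features, p = 365.25>, beta_yearly); first argument lost
    error = pm.HalfCauchy('sigma', .5)
    # The likelihood line was truncated in the post; a Normal likelihood is assumed here
    y_obs = pm.Normal('y_obs', mu = trend, sigma = error, observed = train_y)
    trace = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains = 4)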
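The `growth` and `offset` lines in the model follow the Prophet-style piecewise-linear trend. A minimal NumPy sketch of that computation is below, assuming the changepoint matrix (called `A` here, a hypothetical name since it is not visible in the post) has `A[i, j] = 1` when observation `i` falls at or after changepoint `s[j]`. The inputs are made-up illustration values, not the poster's data.

```python
import numpy as np


def changepoint_matrix(t: np.ndarray, s: np.ndarray) -> np.ndarray:
    """A[i, j] = 1.0 if time t[i] is at or after changepoint s[j], else 0.0."""
    return (t[:, None] >= s[None, :]).astype(float)


def piecewise_trend(t, s, k, m, delta):
    """Prophet-style piecewise-linear trend with slope changes `delta` at `s`."""
    A = changepoint_matrix(t, s)
    growth = k + A @ delta          # slope in effect at each time point
    offset = m + A @ (-s * delta)   # intercept shifts that keep the line continuous
    return growth * t + offset


# Hypothetical inputs: time rescaled to [0, 1] and three changepoints
t = np.linspace(0.0, 1.0, 1001)
s = np.array([0.25, 0.5, 0.75])
delta = np.array([0.5, -1.0, 2.0])
trend = piecewise_trend(t, s, k=1.0, m=0.0, delta=delta)
```

With 500,000 observations, `A` is a dense 500,000 × `n_changepoints` matrix, and both matrix-vector products over it are re-evaluated on every gradient step. Large, regular array work like this is where a GPU tends to pay off.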

In general, MCMC uses one CPU core per chain, so adding cores beyond the number of chains won’t improve speed. Draws may be held in memory during sampling, so you need enough memory (and large models need somewhat more), but memory is not directly related to speed. Core speed (CPU or GPU) affects sampling speed, as does the geometry of your model (e.g., difficult-to-sample priors). A GPU can definitely help. See here for an example.
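The reason extra cores don't speed up a single chain is that every MCMC step depends on the previous draw. The toy random-walk Metropolis sampler below (plain Python, for illustration only; PyMC's default sampler is NUTS, not this) makes the dependence explicit: the loop inside a chain is inherently sequential, and only the independent chains can be spread across cores.

```python
import math
import random


def log_density(x: float) -> float:
    """Unnormalised log density of a standard normal target."""
    return -0.5 * x * x


def metropolis_chain(n_steps: int, seed: int, step_size: float = 1.0) -> list:
    """Random-walk Metropolis: each draw depends on the one before it."""
    rng = random.Random(seed)
    x = 0.0
    draws = []
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step_size)
        # Accept with probability min(1, p(proposal) / p(x));
        # 1.0 - rng.random() lies in (0, 1], so the log is always defined
        if math.log(1.0 - rng.random()) < log_density(proposal) - log_density(x):
            x = proposal
        draws.append(x)
    return draws


def run_chains(n_chains: int, n_steps: int) -> list:
    # Chains share no state, so a sampler can hand each one to its own core.
    # Here they simply run one after another.
    return [metropolis_chain(n_steps, seed=c) for c in range(n_chains)]


chains = run_chains(n_chains=4, n_steps=5000)
```

With `chains=4`, at most four cores have useful work to do, no matter how many the machine has.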


So if I’m reading the git link correctly, all I have to do is the following and PyMC/JAX will use the GPU?

import os

platform = "gpu"  # set to "cpu" or "gpu"; this must run before JAX is imported
assert platform in ["cpu", "gpu"]

if platform == "cpu":
    # Disable GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
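The snippet only controls which devices JAX can see, and JAX reads these settings once, when it is first imported. A related CPU-side detail: NumPyro runs chains in parallel only when JAX exposes at least as many devices as chains, and by default a CPU counts as one device, so chains are drawn one after another. The sketch below wraps both settings in a hypothetical helper, `configure_jax_platform` (not a PyMC or JAX function). It assumes a JAX version that reads `JAX_PLATFORMS` (older releases used `JAX_PLATFORM_NAME`) and relies on the XLA flag `--xla_force_host_platform_device_count`.

```python
import os


def configure_jax_platform(platform: str, n_chains: int = 4) -> None:
    """Hypothetical helper: set env vars JAX reads when it is first imported.

    Call it before `import jax`, `import numpyro` or `import pymc.sampling_jax`.
    """
    assert platform in ["cpu", "gpu"]
    if platform == "cpu":
        # Hide all GPUs from CUDA and pin JAX to its CPU backend
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        os.environ["JAX_PLATFORMS"] = "cpu"
        # Expose one XLA host device per chain so parallel chains can
        # run on separate CPU cores instead of sequentially
        flags = os.environ.get("XLA_FLAGS", "")
        os.environ["XLA_FLAGS"] = (
            f"{flags} --xla_force_host_platform_device_count={n_chains}".strip()
        )
    else:
        # Undo a CPU-only setup; JAX then uses the GPU if a CUDA-enabled jaxlib is installed
        os.environ.pop("JAX_PLATFORMS", None)
        if os.environ.get("CUDA_VISIBLE_DEVICES") == "":
            del os.environ["CUDA_VISIBLE_DEVICES"]
```

For example, call `configure_jax_platform("cpu", n_chains=4)` at the very top of the notebook, then import `pymc.sampling_jax`; `jax.devices()` should then list four CPU devices. On a single GPU, NumPyro can instead run all chains on that one device with `chain_method="vectorized"`, which `sample_numpyro_nuts` accepts as a keyword argument in PyMC v4 (check your version's signature).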

I will let someone with more JAX/GPU experience (@twiecki @ricardoV94 ?) weigh in on the implementation details. As we have discussed, the user guide for v4/JAX/GPU is not yet ready, but hopefully we can get you sorted out.