Memory spike at the end of the MCMC sampling

Hi there

Our workplace is fitting relatively simple GLMs to datasets with approximately 1 million to 10 million observations.

The models themselves aren’t overly complex – either simple linear regression or GLMs with a design matrix of approximately 1 to 5 columns.

I have included a reproducible example below which uses the linear regression example from the pymc3 Getting Started guide, but increases the size of the data to 2 million observations. This produces similar behaviour to the models we are fitting.

When fitting the models with the default MCMC sampler, memory usage remains stable at about 10GB during sampling (sampling 2000 draws, including warm-up, takes about an hour).

But when the sampling finishes there is a massive memory spike. The memory usage spikes to 32GB (the maximum on my local machine) for a couple of minutes. Sometimes the machine even runs out of memory.

We are even running AWS EC2 containers with 64GB RAM that are crashing from lack of memory with only about 1 million observations.

My questions:

  • What operation is happening at the end of the sampling that requires such a massive memory spike for a short period of time (e.g. collating the arrays for the multiple chains)?

  • And is there any way we can potentially avoid it?

If not, then it seems we will have to run a memory-intensive machine, even though the maximum memory is only used for about 1% of the total compute time.

Any advice much appreciated! :grinning:

Reproducible example just in case (run time is about one hour though):

import numpy as np
import pymc3 as pm

print('Running on PyMC3 v{}'.format(pm.__version__))  # v3.9.2

size = 2000000

true_intercept = 1
true_slope = 2

x = np.linspace(0, 1, size)
true_regression_line = true_intercept + true_slope * x
y = true_regression_line + np.random.normal(scale=.5, size=size)

data = dict(x=x, y=y)

# model specifications in PyMC3 are wrapped in a with-statement
with pm.Model() as model:
    sigma = pm.HalfCauchy('sigma', beta=10, testval=1.)
    intercept = pm.Normal('Intercept', 0, sigma=20)
    x_coeff = pm.Normal('x', 0, sigma=20)
    likelihood = pm.Normal('y', mu=intercept + x_coeff * x, sigma=sigma, observed=y)
    trace = pm.sample(1000, cores=2)

The issue is caused by computing and storing the pointwise log likelihood values, a step that happens at the end of sampling, when ess and rhat are calculated. The default is to store this data because it is required for loo/waic calculation and further model comparison.
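A rough back-of-the-envelope check shows why this step is so expensive: the pointwise log likelihood holds one value per chain, draw and observation. The sketch below assumes float64 values (8 bytes) and 2 chains, which is what `sample(1000, cores=2)` in the example above should give, and it counts only the final array, not any temporaries created while computing it. `pointwise_loglik_bytes` is a hypothetical helper, not a PyMC3 or ArviZ function.

```python
def pointwise_loglik_bytes(chains, draws, n_obs, itemsize=8):
    """Bytes needed for a (chain, draw, observation) array of log likelihood values."""
    return chains * draws * n_obs * itemsize


# Reproducible example above: assumed 2 chains x 1000 draws,
# 2 million observations, float64 (8 bytes per value)
n_bytes = pointwise_loglik_bytes(chains=2, draws=1000, n_obs=2_000_000)
print(f"{n_bytes / 1e9:.1f} GB")  # 32.0 GB
```

That estimate is in the same range as the 32GB spike reported above, and it grows linearly with the number of observations.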

Taking a look at the answers to

should give more details and guidance on avoiding the issue.

I am also interested in these use cases, where due to the large number of observations the pointwise log likelihood or posterior predictive samples do not fit in memory. We are working on integrating Dask with ArviZ (see work started on to eventually allow ppc checks and loo/waic calculation for these models.


Awesome! Thanks @OriolAbril for the fast and informative reply.

I will try this out and see how I get on.

fwiw, the loo package in R has a loo.function() method for dealing with this issue. For large datasets, loo etc. can be calculated iteratively by evaluating loo for each observation in the dataset and summing the results as it goes. This avoids holding the entire pointwise log likelihood matrix in memory at any one time. It is slower, but it avoids the memory issue. Not sure if something similar could be an option for the ArviZ loo, but thought I'd mention it just in case. Obviously it would require a log likelihood function for a single observation (not sure this can be derived from the model, so maybe it would have to be specified explicitly by the user), and it also requires the data to be stored or provided by the user when they go to evaluate loo. So it might not be ideal within the pymc3 / arviz framework. Maybe Dask is a more viable solution.
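The one-observation-at-a-time idea can be sketched in Python. This is a minimal sketch, not ArviZ API and not a port of loo.function: to keep it short it accumulates WAIC rather than PSIS-LOO (which additionally needs Pareto smoothing of importance weights). It assumes a user-supplied function returning the log likelihood of observation `i` under every posterior draw; `waic_one_obs_at_a_time` and `loglik_i` are hypothetical names, and the "posterior draws" are simulated stand-ins.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp


def waic_one_obs_at_a_time(loglik_i, n_obs, n_samples):
    """Accumulate elpd_waic and p_waic one observation at a time.

    loglik_i(i) is a hypothetical user-supplied function returning the log likelihood
    of observation i under each posterior draw (1-D array of length n_samples), so the
    full (n_samples, n_obs) matrix is never held in memory.
    """
    elpd_waic = 0.0
    p_waic = 0.0
    for i in range(n_obs):
        ll = loglik_i(i)
        lppd_i = logsumexp(ll) - np.log(n_samples)  # log pointwise predictive density
        p_i = np.var(ll, ddof=1)                    # posterior variance of the log likelihood
        elpd_waic += lppd_i - p_i
        p_waic += p_i
    return elpd_waic, p_waic


# Usage with simulated data and simulated stand-ins for posterior draws
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 1000)
y = 1 + 2 * x + rng.normal(scale=0.5, size=x.size)

n_samples = 500
intercept = rng.normal(1, 0.02, n_samples)
slope = rng.normal(2, 0.03, n_samples)
sigma = np.abs(rng.normal(0.5, 0.01, n_samples))


def loglik_i(i):
    # Normal likelihood of observation i under every posterior draw
    return stats.norm.logpdf(y[i], loc=intercept + slope * x[i], scale=sigma)


elpd_waic, p_waic = waic_one_obs_at_a_time(loglik_i, n_obs=x.size, n_samples=n_samples)
print(elpd_waic, p_waic)
```

The same sums computed on the full matrix give identical numbers; the loop simply trades speed for memory, as the post describes.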


@OriolAbril Just an update here… I was still having some memory issues even when specifying idata_kwargs = dict(log_likelihood = False).

But I think I have tracked it down: the problem is not the behaviour within pymc3.sample() itself but subsequent summary function calls (e.g. pymc3.summary()) on the pymc3 trace, which implicitly convert the pymc3 MultiTrace object to an arviz InferenceData object.

My guess is that returning a pymc3 MultiTrace object means that at subsequent function calls (e.g. pymc3.summary()) the object needs to be converted to an InferenceData object and the default log_likelihood = True will be used when doing that conversion (because the idata_kwargs I specify during model fitting aren’t carried forward to the later summary methods).

Returning an InferenceData object at the time the model is estimated seems to avoid the issues altogether. I’m assuming this is because the pointwise log likelihood is never evaluated - not at the end of the sampling and not in any subsequent post-estimation methods (e.g. pymc3.summary()).

So to clarify, I am finding the following encounters no memory issues:

pymc_fit = pymc3.sample(
    **sampling_params,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False),
)

but the following encounters a memory issue when pymc3.summary() is called:

pymc_fit = pymc3.sample(
    **sampling_params,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False),
)

print(pymc3.summary(pymc_fit))  # memory issue here

I will just switch to using an InferenceData object to solve this at my end (and as I understand it, that will soon be the pymc3.sample() default anyway). But thought it was worth noting here in case someone else is running into the same issue.

Full reprex below in case you want to reproduce the behaviour easily…

import numpy
import pymc3

sampling_params = dict(
    draws=30,
    tune=1,
    chains=4,
    cores=4,
)

size = 10000000

true_intercept = 1
true_slope = 2
x = numpy.linspace(0, 1, size)
true_regression_line = true_intercept + true_slope * x
y = true_regression_line + numpy.random.normal(scale=.5, size=size)

data = dict(x=x, y=y)

with pymc3.Model() as pymc_model:
    intercept = pymc3.Normal('intercept', mu=0, sigma=20)
    beta = pymc3.Normal('beta', mu=0, sigma=20)
    sigma = pymc3.HalfCauchy('sigma', beta=10, testval=1.)
    mu = intercept + beta * data['x']
    likelihood = pymc3.Normal('y', mu=mu, sigma=sigma, observed=data['y'])

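# Return InferenceData directly, without the pointwise log likelihood
pymc_fit1 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False),
)

print(pymc3.summary(pymc_fit1)) # no memory spike

# Return a MultiTrace: pymc3.summary() converts it to InferenceData again
pymc_fit2 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False),
)

print(pymc3.summary(pymc_fit2)) # memory spike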

Yes, ArviZ functions take InferenceData, or anything that can be converted to InferenceData, as input. That means ArviZ functions work when given PyMC3 MultiTraces, PyStan fits…, but they work by converting to InferenceData on each call and then using the resulting InferenceData. I can't recommend strongly enough converting to InferenceData (either az.from_pymc3 or return_inferencedata=True is fine) and then calling ArviZ functions on that InferenceData object. And since PyMC3 has delegated stats and plotting to ArviZ for several versions now, pymc3.summary, pymc3.traceplot, pymc3.loo… are all ArviZ functions.

We recently updated the docs on the pymc3 website to link directly to the ArviZ docs, but the internal conversion is probably not well explained anywhere. Using InferenceData also has other advantages beyond this performance effect: you get an HTML representation of your object, labeled dimensions and coordinates, automatic broadcasting… The radon notebook is a good example of integrating ArviZ into the workflow. Hopefully the rest of the documentation will also be updated progressively to avoid patterns such as calling ArviZ functions with MultiTraces.

Regarding the comment about loo.function, also note that one of ArviZ's goals is to make it easy to share and reproduce results. Having the InferenceData corresponding to a particular model should be enough to repeat the analysis and exploration of the results: plots, ppc checks, model comparison… This does not only affect ArviZ-PyMC3 but also PyStan, Pyro, even Turing in Julia, and hopefully InferenceData stored as netCDF will soon be compatible with the posterior R package too.

Supporting a function, as loo does, is interesting for minimizing memory usage in these cases, but it's not really compatible with sharing the netCDF with a colleague who uses Julia so they can repeat and extend the analysis, or with publishing the netCDF together with your paper so R users can reproduce your results and make sure their implementation is equivalent before using it in their own models.


Thanks @OriolAbril for the detailed response!

Sounds great. And makes sense about the loo.function. ArviZ is a really cool initiative. It will be awesome for me and my colleagues to be able to switch our estimation engine between pymc3, pyro and pystan and not have to rewrite (or add conditional behaviour throughout) all of our post-estimation workflow.


I was wondering if there is an updated approach to handling this memory spike when calculating the log_likelihood after sampling has completed? I've tried adding the following line to enable Dask, but I am still running out of memory when computing the log_likelihood: az.Dask.enable_dask(dask_kwargs={"dask": "parallelized"})

I am also curious about updates to address this problem. Also, would this be the same reason for out-of-memory errors during the “transforming variables” process after sampling with the Numpyro JAX backend?

Lastly, just to confirm, when @OriolAbril says “The default is to store such data because it is required for loo/waic calculation and further model comparison,” does that mean we would not be able to compare these models with others using LOO or WAIC if we set idata_kwargs={"log_likelihood": False}?

This is something that we would like to do, and we have discussed integrating the log likelihood computation and storage with Dask, but it isn't available yet. For now, the best workaround in my opinion is to compute the log likelihood manually and add it to the InferenceData. There is one example in "Refitting PyMC3 models with ArviZ (and xarray)" in the ArviZ dev documentation, but now with xarray-einstats it is no longer necessary to use apply_ufunc manually.

Assuming you had a linear regression model with a Student-t likelihood, it would look similar to:

# Assumes `idata` is an existing InferenceData with posterior, constant_data
# and observed_data groups
from xarray_einstats.stats import XrContinuousRV
from scipy import stats

post = idata.posterior
const = idata.constant_data
mu = post["intercept"] + post["beta"] * const["x"]
df = 2.7
dist = XrContinuousRV(stats.t, df, mu, post["sigma"])
# Use dask="parallelized" or dask="allowed". I am positive parallelized mode will work,
# but allowed might too; it depends on scipy internals, and if it works it will be more efficient
log_lik = dist.logpdf(idata.observed_data["y"], dask="parallelized")  # this will be a dataarray
# but inferencedata groups are datasets, not dataarrays
idata.add_groups(log_likelihood=log_lik.to_dataset(name="y"))
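When xarray-einstats or Dask is not an option, the same Student-t computation can be sketched with plain NumPy/SciPy, one block of observations at a time. `student_t_loglik_chunked` is a hypothetical helper, not part of ArviZ; it assumes posterior draws flattened over chain and draw, `df = 2.7` as above, and float32 storage to halve the size of the result. The data and draws in the usage part are simulated.

```python
import numpy as np
from scipy import stats


def student_t_loglik_chunked(y, x, intercept, beta, sigma, df=2.7, chunk_size=5000):
    """Pointwise Student-t log likelihood of y ~ t(df, intercept + beta * x, sigma).

    intercept, beta, sigma: 1-D arrays of posterior draws (chain and draw flattened).
    Returns a float32 array of shape (n_draws, n_obs). Each iteration only builds
    temporaries of shape (n_draws, chunk_size), so peak extra memory is bounded
    by the chunk size rather than by the number of observations.
    """
    out = np.empty((intercept.size, y.size), dtype=np.float32)
    for start in range(0, y.size, chunk_size):
        stop = min(start + chunk_size, y.size)
        mu = intercept[:, None] + beta[:, None] * x[None, start:stop]
        out[:, start:stop] = stats.t.logpdf(
            y[None, start:stop], df, loc=mu, scale=sigma[:, None]
        )
    return out


# Usage with simulated data and simulated stand-ins for posterior draws
rng = np.random.default_rng(1)
x = np.linspace(0, 1, 12_345)
y = 1 + 2 * x + 0.5 * rng.standard_t(2.7, size=x.size)

n_draws = 200
intercept = rng.normal(1, 0.01, n_draws)
beta = rng.normal(2, 0.01, n_draws)
sigma = np.abs(rng.normal(0.5, 0.01, n_draws))

log_lik = student_t_loglik_chunked(y, x, intercept, beta, sigma, chunk_size=1000)
print(log_lik.shape, log_lik.dtype)  # (200, 12345) float32
```

To add it to the InferenceData you would still reshape it to (chain, draw, observation) and wrap it in an xarray Dataset. If even the float32 result is too large, the output array could instead be disk-backed, for example created with `np.lib.format.open_memmap`.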

I am not very familiar with numpyro, so I don’t know about this.

Exactly: if you set log_likelihood to False you won't be able to use waic or loo unless you compute it manually (as shown above, for example; how it is computed doesn't matter, what matters is that the data is there).

Extra note: We are also planning to work on this, but I am not sure az.loo or az.waic will already work with log likelihood arrays that don’t fit in memory. If this is something you are interested in and can help out it’d be very welcome.


Thanks for the example using xarray_einstats! Unfortunately I'm trying to calculate the log_likelihood on a dataset consisting of X1, X2, Y, where the dataset has 591591 rows. After sampling only one chain of 1000 draws, and then casting the constant_data and posterior xarrays from float64 to float32, the system needs an array of shape (1, 1000, 591591, 591591), i.e. 1.24 PiB of memory, to compute the step mu = post["α"] + post["β_1"] * const["X_1"] + post["β_2"]*const["X_2"]. Is there another way to break this into chunks?

Naively I’d think the right shape is 1, 1000, 591k, not the 591k twice. What dimensions do you have on the constant data group?

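It might help to go through the blog post "PyMC 4.0 with labeled coords and dims" on Oriol unraveled. In xarray, what matters is the dimension name, not its length: arrays whose dimensions have different names are broadcast against each other, even if their lengths are equal.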
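A small xarray sketch of how the (1, 1000, 591591, 591591) shape can arise. It assumes the two covariates ended up in the constant data with separate dimension names; `X_1_dim_0` and `X_2_dim_0` are hypothetical automatic names used for illustration, so check the actual dims in `idata.constant_data`. The values are random stand-ins, and `alpha` replaces `α` only for readability.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 1, 4, 6

# Posterior draws with dims (chain, draw); variable names are illustrative
post = xr.Dataset(
    {
        name: (("chain", "draw"), rng.normal(size=(n_chains, n_draws)))
        for name in ["alpha", "beta_1", "beta_2"]
    }
)

x1 = rng.normal(size=n_obs)
x2 = rng.normal(size=n_obs)

# Covariates stored with two *different* dimension names: xarray takes the outer product
const_bad = xr.Dataset({"X_1": ("X_1_dim_0", x1), "X_2": ("X_2_dim_0", x2)})
mu_bad = post["alpha"] + post["beta_1"] * const_bad["X_1"] + post["beta_2"] * const_bad["X_2"]
print(mu_bad.dims, mu_bad.shape)  # ('chain', 'draw', 'X_1_dim_0', 'X_2_dim_0') (1, 4, 6, 6)

# Same values stored along one shared observation dimension: elementwise, as intended
const_good = xr.Dataset({"X_1": ("obs_id", x1), "X_2": ("obs_id", x2)})
mu_good = post["alpha"] + post["beta_1"] * const_good["X_1"] + post["beta_2"] * const_good["X_2"]
print(mu_good.dims, mu_good.shape)  # ('chain', 'draw', 'obs_id') (1, 4, 6)
```

With 591591 observations the first version needs a 591591 × 591591 block per draw, while the second only needs one value per observation per draw.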

Also, not sure if this was clear already: the data in the InferenceData will be stored as numpy arrays. You will need to chunk it (i.e. convert it to dask arrays) before computing the log likelihood. If the 591k dim is called "obs_id" you can do:

idata = idata.chunk(obs_id=5000)

Important: the chunk size is critical for efficiency. Chunks that are too small will be too slow, and chunks that are too big won't allow proper parallelization (which needs multiple chunks loaded into RAM at once).
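To get a feel for this trade-off, here is a rough estimate of how many chunks the log likelihood array would be split into and how large each chunk is, for the figures in this thread (1 chain, 1000 draws, 591591 observations, float32). `chunk_plan` is a hypothetical helper; it counts only the output array, not Dask's scheduling overhead or intermediate arrays, and it assumes the dimensions are already shared as discussed above.

```python
import math


def chunk_plan(n_obs, chunk_size, chains, draws, itemsize=4):
    """Number of chunks, bytes per chunk and total bytes for a (chain, draw, obs)
    log likelihood array split along the observation dimension."""
    n_chunks = math.ceil(n_obs / chunk_size)
    bytes_per_chunk = chains * draws * chunk_size * itemsize
    total_bytes = chains * draws * n_obs * itemsize
    return n_chunks, bytes_per_chunk, total_bytes


for chunk_size in (500, 5000, 50000):
    n_chunks, per_chunk, total = chunk_plan(591591, chunk_size, chains=1, draws=1000)
    print(f"chunk_size={chunk_size}: {n_chunks} chunks of {per_chunk / 1e6:.0f} MB")
print(f"total: {total / 1e9:.2f} GB")  # total: 2.37 GB
```

Smaller chunks mean more tasks to schedule; larger ones mean fewer tasks but more RAM per worker, multiplied by the number of chunks processed in parallel.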
