Memory spike at the end of the MCMC sampling

Hi there

Our workplace is fitting relatively simple GLMs to datasets of approximately 1 million to 10 million observations.

The models themselves aren’t overly complex – either simple linear regression or GLMs with a design matrix of approximately 1 to 5 columns.

I have included a reproducible example below which uses the linear regression example from the pymc3 Getting Started guide, but increases the size of the data to 2 million observations. This produces similar behaviour to the models we are fitting.

When fitting the models using the default MCMC sampler, memory usage remains stable at ~10GB during sampling (taking about an hour to sample 2000 draws, including warm-up).

But when the sampling finishes there is a massive memory spike: usage jumps to 32GB (the maximum on my local machine) for a couple of minutes, and sometimes the machine runs out of memory entirely.

We are even running AWS EC2 instances with 64GB RAM that are crashing from lack of memory with only about 1 million observations.

My questions:

  • What operation is happening at the end of the sampling that requires such a massive memory spike for a short period of time (e.g. collating the arrays for the multiple chains)?

  • And is there any way we can potentially avoid it?

If not, then it seems we will have to run a memory-intensive machine whose maximum memory is only utilised for ~1% of the total compute time.

Any advice much appreciated! :grinning:

Reproducible example just in case (run time is about one hour though):

import numpy as np
import pymc3 as pm

print('Running on PyMC3 v{}'.format(pm.__version__))  # v3.9.2

size = 2000000

true_intercept = 1
true_slope = 2

# simulate data from a known regression line
x = np.linspace(0, 1, size)
true_regression_line = true_intercept + true_slope * x
y = true_regression_line + np.random.normal(scale=.5, size=size)

data = dict(x=x, y=y)

with pm.Model() as model:  # model specifications in PyMC3 are wrapped in a with-statement
    sigma = pm.HalfCauchy('sigma', beta=10, testval=1.)
    intercept = pm.Normal('Intercept', 0, sigma=20)
    x_coeff = pm.Normal('x', 0, sigma=20)
    likelihood = pm.Normal('y', mu=intercept + x_coeff * x, sigma=sigma, observed=y)
    trace = pm.sample(1000, cores=2)

The issue is due to storing the pointwise log-likelihood values, a step that is done at the end of sampling when ESS and R-hat are calculated. The default is to store this data because it is required for loo/waic calculation and further model comparison.

Taking a look at the answers to the linked topic should give more details and guidance on avoiding the issue.
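
In the meantime, a minimal sketch of turning that storage off for the model in your reprex, assuming PyMC3 >= 3.9 (where sample() accepts return_inferencedata and forwards idata_kwargs to the InferenceData conversion):

with model:
    trace = pm.sample(
        1000,
        cores=2,
        # return an InferenceData object directly and skip computing the
        # (n_chains x n_draws x n_obs) pointwise log-likelihood array
        return_inferencedata=True,
        idata_kwargs=dict(log_likelihood=False),
    )

The trade-off is that loo/waic can no longer be computed from the resulting object, since they need exactly those pointwise values.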

I am also interested in these use cases where, due to the large number of observations, the pointwise log likelihood or posterior predictive samples do not fit in memory. We are working on integrating Dask with ArviZ (see the work started in https://github.com/arviz-devs/arviz/pull/1229) to eventually allow ppc checks and loo/waic calculation for these models.


Awesome! Thanks @OriolAbril for the fast and informative reply.

I will try this out and see how I get on.

fwiw, the loo package in R has a loo.function() method for dealing with this issue. For large datasets, loo etc. can be calculated iteratively by evaluating loo for each observation in the dataset and accumulating the results as it goes. This avoids having to store the entire pointwise log-likelihood matrix in memory at any one time. It is slower, but it avoids the memory issue. Not sure if something similar could be an option for the ArviZ loo, but I thought I'd mention it just in case; a rough sketch of the idea is below. Obviously it would mean having the log-likelihood function definition for a single observation (not sure this can be derived from the model, so it might have to be explicitly specified by the user), and it also requires the data to be stored or provided by the user when they go to evaluate loo. So it might not be ideal within the pymc3/arviz framework. Maybe Dask is a more viable solution.
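
To make that concrete, here is a hedged sketch of the observation-at-a-time pattern for the linear model in this thread, using NumPy/SciPy plus ArviZ's psislw helper for the Pareto smoothing (assumed here to accept a plain 1-D array of draws). streaming_elpd_loo and the per-draw parameter arrays are illustrative names, not an existing pymc3/ArviZ API, and the sketch skips the relative-efficiency correction that az.loo applies:

import numpy as np
import arviz as az
from scipy.special import logsumexp
from scipy.stats import norm

def streaming_elpd_loo(intercept, beta, sigma, x, y):
    # intercept, beta, sigma: 1-D arrays of posterior draws
    # x, y: 1-D data arrays
    # Accumulates elpd_loo one observation at a time, so the full
    # (n_draws x n_obs) pointwise log-likelihood matrix is never
    # held in memory.
    elpd = 0.0
    for xi, yi in zip(x, y):
        # log p(y_i | theta_s) for every posterior draw s
        log_lik_i = norm.logpdf(yi, loc=intercept + beta * xi, scale=sigma)
        # Pareto-smoothed importance weights for leaving out observation i
        log_weights, _ = az.psislw(-log_lik_i)
        elpd += logsumexp(log_weights + log_lik_i)
    return elpd

A pure-Python loop over millions of observations would of course be slow; processing chunks of observations at a time would keep the same memory profile while being much faster.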


@OriolAbril Just an update here… I was still having some memory issues even when specifying idata_kwargs=dict(log_likelihood=False).

But I think I have tracked it down, not to the behaviour within pymc3.sample() itself, but to subsequent function calls (e.g. pymc3.summary()) on the pymc3 trace that implicitly convert the pymc3 MultiTrace object to an arviz InferenceData object.

My guess is that returning a pymc3 MultiTrace object means that subsequent function calls (e.g. pymc3.summary()) need to convert the object to an InferenceData object, and the default log_likelihood=True is used when doing that conversion (because the idata_kwargs I specify during model fitting aren't carried forward to the later summary methods).

Returning an InferenceData object at the time the model is estimated seems to avoid the issues altogether. I’m assuming this is because the pointwise log likelihood is never evaluated - not at the end of the sampling and not in any subsequent post-estimation methods (e.g. pymc3.summary()).

So to clarify, I am finding the following encounters no memory issues:

pymc_fit = pymc3.sample(
    model=pymc_model,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False))

print(pymc3.summary(pymc_fit))

but the following encounters a memory issue when pymc3.summary() is called:

pymc_fit = pymc3.sample(
    model=pymc_model,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False))

print(pymc3.summary(pymc_fit))  # memory issue here

I will just switch to using an InferenceData object to solve this at my end (and as I understand it, that will soon be the pymc3.sample() default anyway). But thought it was worth noting here in case someone else is running into the same issue.

Full reprex below in case you want to easily reproduce the behaviour…

import numpy
import pymc3

sampling_params = dict(
    draws = 30,
    tune = 1,
    chains = 4,
    cores = 4,
)

size = 10000000

true_intercept = 1
true_slope = 2
x = numpy.linspace(0, 1, size)
true_regression_line = true_intercept + true_slope * x
y = true_regression_line + numpy.random.normal(scale=.5, size=size)

data = dict(x=x, y=y)

with pymc3.Model() as pymc_model:
    intercept = pymc3.Normal('intercept', mu=0, sigma=20)
    beta = pymc3.Normal('beta', mu=0, sigma=20)
    sigma = pymc3.HalfCauchy('sigma', beta=10, testval=1.)
    mu = intercept + beta * data['x']
    likelihood = pymc3.Normal('y', mu=mu, sigma=sigma, observed=data['y'])

pymc_fit1 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False))

print(pymc3.summary(pymc_fit1)) # no memory spike

pymc_fit2 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False))

print(pymc3.summary(pymc_fit2)) # memory spike

Yes, ArviZ functions take InferenceData, or something that can be converted to InferenceData, as input. That means ArviZ functions work when they get PyMC3 MultiTraces, PyStan fits… but they work by converting to InferenceData first and then using the resulting InferenceData. I can’t recommend strongly enough converting to InferenceData once (either az.from_pymc3 or return_inferencedata=True is fine) and then calling ArviZ functions on that InferenceData object. And given that PyMC3 has delegated stats and plotting to ArviZ for some versions now, pymc3.summary, pymc3.traceplot, pymc3.loo… are all ArviZ functions.
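
For instance, a minimal sketch of the convert-once pattern, reusing pymc_model and pymc_fit2 from the reprex above (the log_likelihood argument of az.from_pymc3 is the same switch that idata_kwargs forwards to):

import arviz as az

with pymc_model:
    # convert the MultiTrace to InferenceData exactly once, again
    # skipping the pointwise log likelihood
    idata = az.from_pymc3(pymc_fit2, log_likelihood=False)

print(az.summary(idata))  # no implicit conversion (or memory spike) here
az.plot_trace(idata)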

We recently updated the docs on the pymc3 website to link directly to the ArviZ docs, but the internal conversion is probably not well explained anywhere. Using InferenceData also has advantages beyond this performance effect: you get an HTML representation of your object, labeled dimensions and coordinates, automatic broadcasting… The radon notebook is a good example of integrating ArviZ into the workflow. Hopefully the rest of the documentation will be updated progressively to avoid patterns such as calling ArviZ functions with MultiTraces.


Regarding the comment about loo.function: also note that one of ArviZ’s goals is to ease sharing and reproducing results. Having the InferenceData corresponding to a particular model should be enough to repeat the result analysis and exploration: plots, ppc checks, model comparison… This affects not only ArviZ-PyMC3 but also PyStan, Pyro, even Turing in Julia, and hopefully InferenceData stored as netCDF will soon be compatible with the posterior R package too.
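
As a sketch of that sharing workflow (the file name is just an example):

import arviz as az

# write the full InferenceData (posterior, observed data, sampling
# stats...) to a self-describing netCDF file
idata.to_netcdf('linear_model.nc')

# anyone with a netCDF reader can load it back and repeat the analysis
idata_again = az.from_netcdf('linear_model.nc')
print(az.summary(idata_again))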

Supporting a user-supplied function as loo does is interesting for minimizing memory usage in these cases, but it’s not really compatible with sharing the netCDF with a colleague who uses Julia so they can repeat and extend the analysis, or with publishing the netCDF together with your paper so R users can reproduce your results and make sure their implementation is equivalent before using it in their own models.


Thanks @OriolAbril for the detailed response!

Sounds great. And makes sense about the loo.function. ArviZ is a really cool initiative. It will be awesome for me and my colleagues to be able to switch our estimation engine between pymc3, pyro and pystan and not have to rewrite (or add conditional behaviour throughout) all of our post-estimation workflow.
