@OriolAbril Just an update here… I was still having some memory issues even when specifying idata_kwargs = dict(log_likelihood = False).
But I think I have tracked it down: the problem isn't in the behaviour of `pymc3.sample()` itself but in subsequent calls such as `pymc3.summary()` on the pymc3 trace, which implicitly convert the pymc3 MultiTrace object to an arviz InferenceData object.
My guess is that when a MultiTrace object is returned, each later call (e.g. `pymc3.summary()`) has to convert it to an InferenceData object, and that conversion uses the default `log_likelihood=True`, because the `idata_kwargs` I specify during model fitting aren't carried forward to the later summary methods.
Returning an InferenceData object at the time the model is estimated avoids the issue altogether. I'm assuming this is because the pointwise log likelihood is never evaluated: not at the end of sampling and not in any subsequent post-estimation methods (e.g. `pymc3.summary()`).
So to clarify, I am finding the following encounters no memory issues:
```python
pymc_fit = pymc3.sample(
    model=pymc_model,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False))
print(pymc3.summary(pymc_fit))
```
but the following encounters a memory issue when `pymc3.summary()` is called:
```python
pymc_fit = pymc3.sample(
    model=pymc_model,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False))
print(pymc3.summary(pymc_fit))  # memory issue here
```
I will just switch to using an InferenceData object to solve this at my end (and as I understand it, that will soon be the `pymc3.sample()` default anyway). But I thought it was worth noting here in case someone else is running into the same issue.
Full reprex below in case you want to easily reproduce the behaviour…
```python
import numpy
import pymc3

sampling_params = dict(
    draws=30,
    tune=1,
    chains=4,
    cores=4,
)

# Simulate a large linear-regression dataset so the memory spike is visible.
size = 10000000
true_intercept = 1
true_slope = 2
x = numpy.linspace(0, 1, size)
true_regression_line = true_intercept + true_slope * x
y = true_regression_line + numpy.random.normal(scale=.5, size=size)
data = dict(x=x, y=y)

with pymc3.Model() as pymc_model:
    intercept = pymc3.Normal('intercept', mu=0, sigma=20)
    beta = pymc3.Normal('beta', mu=0, sigma=20)
    sigma = pymc3.HalfCauchy('sigma', beta=10, testval=1.)
    mu = intercept + beta * data['x']
    likelihood = pymc3.Normal('y', mu=mu, sigma=sigma, observed=data['y'])

pymc_fit1 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=True,
    idata_kwargs=dict(log_likelihood=False))
print(pymc3.summary(pymc_fit1))  # no memory spike

pymc_fit2 = pymc3.sample(
    model=pymc_model,
    **sampling_params,
    return_inferencedata=False,
    idata_kwargs=dict(log_likelihood=False))
print(pymc3.summary(pymc_fit2))  # memory spike
```