I am solving a coin flipping problem with
pymc (version 4). Here is the (very simple) model:
import numpy as np
import pymc as pm
import pymc.sampling_jax as jx

def create_coin(data):
    with pm.Model() as coin_model:
        p_coin = pm.Beta("p_coin", 2, 2)
        heads = pm.Bernoulli("flips", p_coin, observed=data)
    return coin_model

p = 0.55
nb_flips = 100
nb_chains = 4
nb_draws = 10000
data = np.random.choice([0, 1], size=nb_flips, p=[1 - p, p])
model = create_coin(data)
with model:
    idata = jx.sample_numpyro_nuts(target_accept=0.9, draws=nb_draws, tune=nb_draws, chains=nb_chains, chain_method='parallel')
When I run 20,000 draws with 100,000 coin flips and 4 chains, I exceed 10 GB of memory on the GPU. I found that pymc preallocates memory in the amount
nb_chains * nb_flips * nb_draws.
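For the sizes quoted above, a quick back-of-the-envelope check shows why that allocation blows past the GPU (the 4-byte float32 entry size here is an assumption about the backend dtype):

```python
# Back-of-the-envelope size of the preallocated pointwise array:
# one entry per (chain, flip, draw) triple.
nb_chains = 4
nb_flips = 100_000
nb_draws = 20_000

n_elements = nb_chains * nb_flips * nb_draws
gb_float32 = n_elements * 4 / 1e9  # assuming 4-byte float32 entries

print(n_elements)  # 8_000_000_000 entries
print(gb_float32)  # 32.0 GB -- far past a 10 GB GPU even at float32
```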
My question is: why must pymc store this much information, given that I am only computing the trace of
p_coin? Why must it store the equivalent of all observables at every draw?
Is there a way to reduce memory usage?
Also, are there routines to track memory usage, not only in JAX, but using pymc's standard tooling?
Thanks for any insight you might provide!
You can set
postprocessing_backend='cpu' to get around that problem.
Thanks. I know I can operate on the CPU, which is much faster. I am still interested in the rationale for how memory is preallocated on the GPU. Thanks.
That’s beyond the scope of PyMC. We simply rely on Numpyro to create a JAX graph from the model logp and sample it. You would have to ask the Numpyro folks for more details, but even they might only forward you to the JAX folks.
Thank you, Ricardo. You are correct of course, but hope springs eternal.
Pymc doesn’t actually need to allocate this memory for inference. It does the allocation when creating the inference data object.
By default, pymc will compute the pointwise log likelihood. It does this so that arviz can run model comparison without needing a reference to the model that generated the inference results (it’s model agnostic and can compare results drawn from different PPLs).
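To make that rationale concrete: arviz routines such as LOO read the stored pointwise log likelihood straight out of the InferenceData, with no model object in sight. A quick illustration using one of arviz's bundled example datasets:

```python
import arviz as az

# centered_eight ships with arviz and already contains a log_likelihood
# group -- the same kind of array pymc precomputes for you.
idata = az.load_arviz_data("centered_eight")
print(idata.log_likelihood)  # pointwise values, one per observation per draw

# Model comparison works off that group alone; no model object is needed.
print(az.loo(idata))
```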
Again, this isn’t a hard requirement: you can ask pymc not to compute the pointwise log likelihood, which will let you avoid the out-of-memory error. To do this, pass the following keyword argument to
sample (or whichever jax variant you choose):
idata_kwargs={"log_likelihood": False}
Thanks. If I only want the posterior, what else should I cut out through proper argument selection?
Just discarding the log likelihood should be fine. The other sampling stats are important for checking your results' quality, and they don’t eat up that much memory.