PyMC/NumPyro GPU memory allocation

I am solving a coin-flipping problem with PyMC (version 4). Here is the (very simple) model:

import numpy as np
import pymc as pm
import pymc.sampling_jax as jx

def create_coin(data):
    with pm.Model() as coin_model:
        p_coin = pm.Beta("p_coin", 2, 2)
        heads = pm.Bernoulli("flips", p_coin, observed=data)
    return coin_model

p = 0.55
nb_flips = 100
nb_chains = 4
nb_draws = 10000
data = np.random.choice([0,1], size=nb_flips, p=[1-p, p])
model = create_coin(data)

with model:
    idata = jx.sample_numpyro_nuts(target_accept=0.9, draws=nb_draws, tune=nb_draws, chains=nb_chains, chain_method='parallel')

When I run 20,000 draws with 100,000 coin flips and 4 chains, I exceed 10 GB of memory on the GPU. I found that PyMC preallocates memory proportional to nb_chains * nb_flips * nb_draws.
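Back-of-the-envelope, assuming 4-byte float32 elements (JAX's default precision), that works out to:

# chains * draws * flips * bytes per element (float32 assumed)
4 * 20_000 * 100_000 * 4 / 1e9   # ≈ 32 GB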

My question is: why must PyMC store this much information, given that I am only computing the trace of p_coin? Why must it store the equivalent of all observables at every draw?
Is there a way to reduce memory usage?
Are there routines to track memory usage, not only with JAX, but also with PyMC's standard sample method?
Thanks for any insight you might provide!

You can set postprocessing_backend='cpu' to get around that problem.
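For example, keeping the rest of the call from the question unchanged:

with model:
    idata = jx.sample_numpyro_nuts(
        target_accept=0.9,
        draws=nb_draws,
        tune=nb_draws,
        chains=nb_chains,
        chain_method='parallel',
        postprocessing_backend='cpu',  # run the post-sampling transforms on the CPU instead of the GPU
    )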

Thanks. I know I can operate on the CPU, which is much faster. I am still interested in the rationale for how memory is preallocated on the GPU. Thanks.

That's beyond the scope of PyMC. We simply rely on NumPyro to create a JAX graph from the model logp and sample it. You would have to ask the NumPyro folks for more details, but even they might only forward you to the JAX folks.

Thank you, Ricardo. You are correct of course, but hope springs eternal.

PyMC doesn't actually need to allocate this memory for inference. It does the allocation when creating the inference data object.
By default, PyMC will compute the pointwise log likelihood, and it does this to enable ArviZ to run model comparison without needing a reference to the model that generated the inference results (it's model agnostic and can compare results drawn from different PPLs).
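For reference, that pointwise array is what ArviZ's comparison routines read; a minimal sketch:

import arviz as az
az.loo(idata)   # PSIS-LOO cross-validation; requires the log_likelihood group in idata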
Again, this isn't a hard requirement, and you can ask PyMC not to compute the pointwise log likelihood, which will let you avoid the out-of-memory error. To do this, pass the following keyword argument to sample (or whichever JAX variant you choose):
idata_kwargs={"log_likelihood": False}
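In full, reusing the call from the question:

with model:
    idata = jx.sample_numpyro_nuts(
        target_accept=0.9,
        draws=nb_draws,
        tune=nb_draws,
        chains=nb_chains,
        chain_method='parallel',
        idata_kwargs={"log_likelihood": False},  # skip the (chains, draws, observations) array
    )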

Thanks. If I only want the posterior, what else should I cut out through proper argument selection?

Just discarding the log likelihood should be fine. The other sampling stats are important for checking the quality of your results, and they don't eat up much memory.
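For instance, a quick look at what those stats buy you (a sketch; the exact stat names can vary by sampler backend):

import arviz as az
az.summary(idata, var_names=["p_coin"])       # ESS and R-hat convergence diagnostics
int(idata.sample_stats["diverging"].sum())    # count of divergent transitions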

Thank you!