PyMC/NumPyro GPU memory allocation

PyMC doesn't actually need to allocate this memory for inference. The allocation happens afterwards, when PyMC builds the InferenceData object from the samples.
By default, PyMC computes the pointwise log likelihood. It does this so that ArviZ can run model comparison without needing a reference to the model that generated the inference results (ArviZ is model agnostic and can compare results drawn from different PPLs).
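To see why this step can exhaust GPU memory, note that the pointwise log likelihood stores one value per chain, per draw and per observation. A rough back-of-the-envelope sketch follows. It assumes one float64 (8 bytes) per entry, ignores any intermediate buffers the computation may need, and uses hypothetical sizes; `pointwise_loglik_bytes` is an illustrative helper, not a PyMC function.

```python
def pointwise_loglik_bytes(chains: int, draws: int, n_obs: int, bytes_per_value: int = 8) -> int:
    """Size in bytes of a dense (chain, draw, observation) log-likelihood array.

    Assumes float64 entries (8 bytes) by default; float32 would halve this.
    """
    return chains * draws * n_obs * bytes_per_value


# Hypothetical example: 4 chains x 1,000 draws x 1,000,000 observations.
size = pointwise_loglik_bytes(chains=4, draws=1_000, n_obs=1_000_000)
print(f"{size / 1e9:.1f} GB")  # -> 32.0 GB
```

The posterior itself only scales with the number of parameters, so with many observations this log-likelihood array can easily dwarf everything else that is kept after sampling.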
Again, this isn't a hard requirement: you can ask PyMC not to compute the pointwise log likelihood, which lets you avoid the out-of-memory error. To do this, pass the following keyword argument to sample (or whichever JAX variant you choose):
idata_kwargs={"log_likelihood": False}
