Hi,
I'm trying to run a large hierarchical model with PyMC + JAX on GPU, but I keep running into out-of-memory (OOM) errors. I've tried various fixes that I found online but haven't been able to solve the problem so far. Is there any way to work around this, for example by using batches? (I've sketched what I mean by batching below, right after the model description.) I'm very new to PyMC and to GPU-accelerated computing in general, so any help is much appreciated!
The model I'm trying to estimate has 10,000 respondents and 280 latent variables. It's a hierarchical model in which each respondent completed 12 tasks, from which I'm hoping to estimate each respondent's individual values for the 280 latent variables. Ideally, I'd like to run 1,000 tuning steps + 1,000 draws, but the best I've managed before hitting the OOM error is 300 tuning steps + 300 draws on a GPU with 24 GB of memory.
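To make the batching idea concrete, this is roughly what I had in mind (a hypothetical sketch, not working code; the batch_size of 500 respondents is a placeholder, and as far as I can tell pm.Minibatch targets variational inference rather than NUTS, which is part of my question):

```python
import pymc as pm

# Hypothetical sketch of what I mean by "batches": stream random subsets of
# respondents instead of holding all 10,000 in GPU memory at once.
# pm.Minibatch seems to be meant for ADVI (pm.fit) rather than NUTS, so I
# don't know whether there is an equivalent trick for the JAX samplers.
batched_choicedata = pm.Minibatch(ts_choicedata, batch_size=500)
```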
Here's my PyMC model:
```python
import numpy as np
import aesara.tensor as at
import pymc as pm

basic_model = pm.Model()
with basic_model:
    # population-level means of the parameters
    alphas = pm.Normal("alphas", mu=0, sigma=1, shape=num_parameters)
    sd_distrib = pm.Normal.dist(mu=1, sigma=10, shape=num_parameters)
    L_sigma, corr, stds = pm.LKJCholeskyCov(
        "L_sigma", eta=1, n=num_parameters, sd_dist=sd_distrib,
        compute_corr=True, store_in_trace=False,
    )
    cov = pm.Deterministic("cov", L_sigma.dot(L_sigma.T))
    # standardized respondent-level offsets (non-centered parameterization)
    betas_init = pm.MvNormal(
        "betas_init", mu=np.zeros(num_parameters), cov=np.eye(num_parameters),
        shape=(num_respondents, num_parameters),
    )

    def logp(choicedata, alphas, betas_init, L_sigma):
        # transformed betas from alphas and betas_init: beta = alpha + z @ L^T
        betas = alphas + at.dot(betas_init, L_sigma.T)
        # respondent-level log-likelihood from my custom function
        resp_log = CHOICE_PROBABILITIES(choicedata, betas, aggregation=True)
        return resp_log

    likelihood = pm.DensityDist(
        "likelihood",
        alphas, betas_init, L_sigma,
        logp=logp,
        observed=ts_choicedata,
    )
```
```python
with basic_model:
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        target_accept=target_accept,
        chains=chains,
        postprocessing_backend="cpu",
        idata_kwargs={"log_likelihood": False},
        progress_bar=True,
    )
```
The CHOICE_PROBABILITIES function is a custom likelihood function that computes the log-likelihood at the individual-respondent level.
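In case the shapes matter for diagnosing the memory use, here is a toy stand-in showing the contract I'm assuming for that function (the names, shapes, and the multinomial-logit form are illustrative only, not my real implementation):

```python
import aesara.tensor as at

# Toy stand-in for CHOICE_PROBABILITIES (illustrative only -- my real function
# is more involved). Assumed shapes:
#   X:     (num_respondents, num_tasks, num_alts, num_parameters) task designs
#   y:     (num_respondents, num_tasks) index of the chosen alternative
#   betas: (num_respondents, num_parameters) respondent-level coefficients
def toy_choice_probabilities(choicedata, betas, aggregation=True):
    X, y = choicedata  # assumed packing: (design matrices, chosen indices)
    util = (X * betas[:, None, None, :]).sum(axis=-1)         # (R, T, A)
    logp = util - at.logsumexp(util, axis=-1, keepdims=True)  # log-softmax
    r = at.arange(y.shape[0])[:, None]                        # respondent index
    t = at.arange(y.shape[1])[None, :]                        # task index
    chosen = logp[r, t, y]                                    # (R, T)
    # with aggregation=True, return one log-likelihood per respondent
    return chosen.sum(axis=-1) if aggregation else chosen
```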
Here are the things that I'm already doing:

- Setting the JAX environment variables before JAX is imported (the plain-Python os.environ equivalent is sketched right after this list):

  ```
  %env XLA_PYTHON_CLIENT_PREALLOCATE=false
  %env XLA_PYTHON_CLIENT_ALLOCATOR=platform
  ```

- Passing `idata_kwargs={'log_likelihood': False}` when sampling.
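Here is the os.environ version of those settings, in case the %env magics matter (this assumes a fresh process where nothing has imported jax yet):

```python
import os

# Must run before jax (or anything that imports jax, e.g. the PyMC JAX
# samplers) is imported, otherwise XLA has already claimed GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # imported only after the allocator settings are in place
```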
I also noticed that with more tuning steps and fewer draws, GPU memory use (reading from nvidia-smi) seems to be lower than when I run the same total number of iterations with fewer tuning steps and more draws.
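One more idea I haven't tried yet: would something like chain_method="sequential" (or simply fewer chains) reduce peak memory, since the chains wouldn't all live on the GPU at once? E.g.:

```python
with basic_model:
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        target_accept=target_accept,
        chains=chains,
        chain_method="sequential",  # run chains one at a time, not in parallel
        postprocessing_backend="cpu",
        idata_kwargs={"log_likelihood": False},
    )
```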
Thanks a lot!