Hi,
I am trying to run a large hierarchical model with PyMC + JAX on a GPU, but I keep running into out-of-memory (OOM) issues. I have tried various solutions that I found online but couldn't really solve the problem so far. Is there any way to work around this, for example by using batches? I'm very new to PyMC and GPU-accelerated computing in general, so any help is much appreciated!
The model that I'm trying to estimate has 10,000 respondents and 280 latent variables. It's a hierarchical model where each respondent completed 12 tasks, from which I was hoping to estimate each respondent's individual values for the 280 latent variables. Ideally, I'd like to run 1,000 tuning iterations + 1,000 draws, but the best I could do so far before hitting the OOM issue was 300 tuning iterations + 300 draws on a GPU with 24 GB of memory.
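Part of the problem might simply be the size of the trace. Here's my back-of-envelope estimate for storing just the `betas_init` draws (assuming float64 and 4 chains, both of which are guesses on my part):

```python
# Rough memory estimate for storing just the betas_init draws
# (assumes float64 = 8 bytes per value and 4 chains; both are guesses)
num_respondents = 10_000
num_parameters = 280
draws = 1_000
chains = 4

total_bytes = num_respondents * num_parameters * draws * chains * 8
print(f"{total_bytes / 1e9:.1f} GB")  # ~89.6 GB total, ~22.4 GB per chain
```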
Here’s my pm model:
```python
import numpy as np
import pymc as pm
import aesara.tensor as at

basic_model = pm.Model()
with basic_model:
    # Population-level means for the parameters
    alphas = pm.Normal("alphas", mu=0, sigma=1, shape=num_parameters)

    # Prior on the standard deviations used by LKJCholeskyCov
    sd_distrib = pm.Normal.dist(mu=1, sigma=10, shape=num_parameters)
    L_sigma, corr, stds = pm.LKJCholeskyCov(
        "L_sigma", eta=1, n=num_parameters, sd_dist=sd_distrib,
        compute_corr=True, store_in_trace=False,
    )
    cov = pm.Deterministic("cov", L_sigma.dot(L_sigma.T))

    # Standardized respondent-level offsets (non-centered parameterization)
    betas_init = pm.MvNormal(
        "betas_init",
        mu=np.zeros(num_parameters),
        cov=np.eye(num_parameters),
        shape=(num_respondents, num_parameters),
    )

    def logp(choicedata, alphas, betas_init, L_sigma):
        # Get the transformed betas from alphas and betas_init;
        # alphas broadcasts across respondents
        betas = alphas + at.dot(betas_init, L_sigma)
        # Get the respondent-level log-likelihood
        resp_log = CHOICE_PROBABILITIES(choicedata, aggregation=True)
        return resp_log

    likelihood = pm.DensityDist(
        "likelihood",
        alphas, betas_init, L_sigma,
        logp=logp,
        observed=ts_choicedata,
    )

with basic_model:
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        target_accept=target_accept,
        chains=chains,
        postprocessing_backend="cpu",
        idata_kwargs={"log_likelihood": False},
        progress_bar=True,
    )
```
The CHOICE_PROBABILITIES function is a customized likelihood function that calculates the log-likelihood at the individual respondent level.
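To give a sense of its shape, it does roughly something like this. This is only a simplified sketch with placeholder names, not my actual code; I'm showing a multinomial-logit style log-likelihood as an example of the structure:

```python
import aesara.tensor as at

def choice_probabilities_sketch(choicedata, utilities, aggregation=True):
    # Sketch only; the real CHOICE_PROBABILITIES is more involved.
    # choicedata: (num_obs, num_alternatives) one-hot matrix of observed choices
    # utilities:  (num_obs, num_alternatives) latent utility of each alternative

    # Log-softmax over alternatives -> log choice probabilities
    log_probs = utilities - at.logsumexp(utilities, axis=1, keepdims=True)
    # Log-likelihood of each observed choice
    ll = at.sum(choicedata * log_probs, axis=1)
    if aggregation:
        return at.sum(ll)  # aggregate to a single scalar log-likelihood
    return ll
```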
Here are the things that I’m already doing:

Adding the JAX env settings before calling JAX (see also the snippet right after this list):
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
Using idata_kwargs={'log_likelihood': False} when sampling
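For the env settings, I also tried setting them with os.environ before JAX is first imported, since my understanding is that they have no effect once JAX is already loaded:

```python
import os

# These must be set before JAX is first imported, otherwise they are ignored.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't preallocate most of the GPU memory up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/deallocate on demand (slower, but returns memory)

import jax  # imported only after the env vars are in place
```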

I also noticed that if I do more tuning iterations and fewer draws, the GPU memory use (reading from nvidia-smi) seems to be smaller than when I run the same total number of iterations with fewer tuning iterations and more draws.
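Based on that, one thing I'm planning to try is running the chains sequentially rather than in parallel via chain_method. I saw this parameter in the sample_numpyro_nuts signature, but I haven't verified that it actually lowers peak memory for my model:

```python
with basic_model:
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        target_accept=target_accept,
        chains=chains,
        chain_method="sequential",  # run one chain at a time instead of in parallel
        postprocessing_backend="cpu",
        idata_kwargs={"log_likelihood": False},
        progress_bar=True,
    )
```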
Thanks a lot!