Batch process capability for pymc.sampling_jax.sample_numpyro_nuts() with GPU?

Hi,

I am trying to run a large hierarchical model with pymc + jax on a GPU, but I keep running into out-of-memory (OOM) issues. I tried various solutions that I found online but haven’t been able to solve the problem so far. Is there any way to work around this, for example by using batches? I’m very new to pymc and GPU-accelerated computing in general, and any help is much appreciated!

The model that I’m trying to estimate has 10,000 respondents and 280 latent variables. It’s a hierarchical model where each respondent completed 12 tasks, from which I was hoping to estimate their individual values for the 280 latent variables. Ideally, I’d like to run 1000 tunes + 1000 draws, but the best I could do so far before running into the OOM issue was 300 tunes + 300 draws on a GPU with 24 GB of memory.

Here’s my pm model:

basic_model = pm.Model()

with basic_model:

    alphas = pm.Normal("alphas", mu=0, sigma=1, shape=num_parameters)
    sd_distrib = pm.Normal.dist(mu=1, sigma=10, shape=num_parameters)
    L_sigma, corr, stds = pm.LKJCholeskyCov('L_sigma', eta=1, n=num_parameters, sd_dist=sd_distrib, compute_corr=True, store_in_trace=False)
    cov = pm.Deterministic("cov", L_sigma.dot(L_sigma.T))
    betas_init = pm.MvNormal('betas_init', mu=np.zeros((num_parameters)), cov=np.eye(num_parameters), shape=(num_respondents, num_parameters))
    
    
    def logp(choicedata, alphas, betas_init, L_sigma):

        # get transformed betas from alphas and betas_init
        betas = alphas.reshape(num_respondents,1) + at.dot(betas_init,L_sigma)
        #get respondent-level log likelihood
        resp_log = CHOICE_PROBABILITIES(choicedata, aggregation = True)
      
        return resp_log
    
    
    likelihood = pm.DensityDist(
        "likelihood",
        alphas, betas_init, L_sigma,
        logp=logp,
        observed = ts_choicedata
        )


with basic_model:
    idata = pm.sampling_jax.sample_numpyro_nuts(draws = draws, 
                                                tune=tune, 
                                                target_accept=target_accept, 
                                                chains=chains,
                                                postprocessing_backend = 'cpu',
                                                idata_kwargs = {'log_likelihood': False},
                                                progress_bar=True,
                                               )

The CHOICE_PROBABILITIES function is a custom likelihood function that calculates the log-likelihood at the individual respondent level.

Here are the things that I’m already doing:

  • Add the jax env setting before calling jax:
    %env XLA_PYTHON_CLIENT_PREALLOCATE=false
    %env XLA_PYTHON_CLIENT_ALLOCATOR=platform

  • Use idata_kwargs = {'log_likelihood': False} when sampling

  • I also noticed that with more tuning steps and fewer draws, GPU memory use (as reported by nvidia-smi) seems to be lower than when I run the same total number of iterations with fewer tuning steps and more draws.
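
Outside a notebook, the same two allocator settings can be applied from plain Python instead of the %env magic. This is only a minimal sketch; it assumes the lines run before JAX is imported for the first time in the process, since settings made afterwards may have no effect.

```python
import os
import sys

# Set these before the first `import jax` (or any import that pulls JAX in,
# such as pymc.sampling_jax); otherwise they may be ignored.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Simple guard: warn if JAX was already imported earlier in this process.
if "jax" in sys.modules:
    print("Warning: jax is already imported; restart the process and set these first.")
```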

Thanks a lot!

You can try setting chain_method to "sequential" or "vectorized" and see if it does better.

It is normal for less RAM to be needed for tuning than drawing, because tuning samples can be discarded over time.
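
A rough estimate shows why the number of retained draws matters so much for a model of this size. The sketch below counts only the free parameters described in the question (alphas, the packed Cholesky factor from LKJCholeskyCov, and betas_init). It assumes every retained draw stays in device memory until sampling finishes. The helper names, the 4- and 8-byte value widths, and the single chain are illustrative assumptions, not measurements of the actual run.

```python
def free_parameter_count(num_respondents: int, num_parameters: int) -> int:
    """Count the free values per draw for the model in the question."""
    alphas = num_parameters
    # A Cholesky factor of an n x n covariance has n * (n + 1) / 2 free entries.
    chol_packed = num_parameters * (num_parameters + 1) // 2
    betas_init = num_respondents * num_parameters
    return alphas + chol_packed + betas_init


def draw_storage_gib(draws: int, chains: int, n_values: int, bytes_per_value: int) -> float:
    """Memory needed to keep all retained draws, in GiB."""
    return draws * chains * n_values * bytes_per_value / 1024**3


n_values = free_parameter_count(num_respondents=10_000, num_parameters=280)
print(f"free values per draw: {n_values:,}")

for draws in (300, 1000):
    for bytes_per_value in (4, 8):  # float32 vs float64 (assumed widths)
        gib = draw_storage_gib(draws, chains=1, n_values=n_values,
                               bytes_per_value=bytes_per_value)
        print(f"{draws} draws, {bytes_per_value} bytes/value, 1 chain: {gib:.1f} GiB")
```

The total scales linearly with the number of draws and chains, and this figure excludes whatever the sampler needs for gradients and intermediate arrays.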

Thanks a lot for your reply!

I’ve tried that, but it didn’t seem to help much. I’m running more tunes in an attempt to reach a stable posterior before drawing, and it seems to improve the accuracy when modelling with my synthetic data.

Are there any other ways to deal with the OOM issue with GPU?

That’s more of a JAX/NumPyro question, so you may have more luck asking in their forums.
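
On the batching idea from the original question: NumPyro's MCMC class documents a post_warmup_state attribute that lets a later run() continue from where the previous one stopped, without repeating warmup. A minimal sketch of sampling in batches this way is below. It assumes you build the numpyro.infer.MCMC object yourself rather than going through pm.sampling_jax.sample_numpyro_nuts, and that its num_samples is set to the batch size. The helper name sample_in_batches and its to_host argument are hypothetical, and whether this frees enough GPU memory for this model is untested.

```python
def sample_in_batches(mcmc, rng_key, num_batches, to_host=None, **run_kwargs):
    """Run `mcmc` several times, copying each batch of draws off the device.

    `mcmc` is assumed to be a numpyro.infer.MCMC instance whose num_samples
    is the batch size. Warmup only happens during the first run() call.
    """
    if to_host is None:
        import jax  # lazy import so the helper can be exercised without JAX

        to_host = jax.device_get  # copies device arrays to host memory

    batches = []
    for _ in range(num_batches):
        mcmc.run(rng_key, **run_kwargs)
        # Move this batch to host memory before the next run replaces it.
        batches.append(to_host(mcmc.get_samples(group_by_chain=True)))
        # Continue from the final state so the next run() skips warmup.
        mcmc.post_warmup_state = mcmc.last_state
        rng_key = mcmc.post_warmup_state.rng_key
    return batches
```

Each element of the returned list is one batch of draws; concatenating them along the draw axis gives the full set of samples on the host.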