Out of Memory during Transforming Variables Step using Numpyro. Increasing postprocessing_chunks doesn't seem to help. What should I do?

I’m working on a complex hierarchical nested logit model that has a fair amount of parameters. Numpyro appears to be the fastest sampler on CPU for my model. The problem is, with a larger number of draws or chains, after hours of sampling, the process always dies during the transforming variables step, and increasing postprocessing_chunks (to 100) doesn’t seem to help. Here is my code

    idata = pm.sample(draws=1000,
                      chains=8,
                      cores=8,
                      tune=500,
                      nuts_sampler="numpyro",
                      idata_kwargs={"log_likelihood": False}, 
                      nuts_sampler_kwargs={"postprocessing_chunks":100},
                      random_seed=1301)

I tried reducing ‘draws’ to 500, and the sampling process did finish. But that was not enough for the model to converge. I’d really appreciate any advice or suggestions. Thank you!

Update pymc, we fixed that recently

Thank you for your quick response. I have been running 5.8.2. I just updated it to 5.9.0. I see that in 5.9.0 the “postprocessing_chunks” parameter is deprecated, and in its place is “postprocessing_vectorize”, whose default option is “scan”. But I’m still running into the out-of-memory issue. Below is the error message. My laptop has 32GB of memory. I’d really appreciate your help. Thank you!

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[8], line 193
    185     like_model = pm.Potential("like_model", TranWeight*pm.logp(pm.Categorical.dist(p=p_model),dv_model),dims="obs")
    187     # idata = pm.sample_prior_predictive()
    188     # idata.extend(
    189     #     pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True}, random_seed=1033)
    190     # )
    191     # idata.extend(pm.sample_posterior_predictive(idata))
--> 193     idata = pm.sample(draws=1000,
    194                       chains=8,
    195                       cores=8,
    196                       tune=500,
    197                       nuts_sampler="numpyro",
    198                       idata_kwargs={"log_likelihood": False}, 
    199                       nuts_sampler_kwargs={"postprocessing_chunks":100},
    200                       random_seed=1301)
    202 #nutpie sampler- very slow for some reasons.
    203 # compiled_model = compile_pymc.compile_pymc_model(model)
    204 # idata = nutpie.sample(compiled_model,    
   (...)
    210 #                       save_warmup = False
    211 #                      )

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:658, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    654     if not isinstance(step, NUTS):
    655         raise ValueError(
    656             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    657         )
--> 658     return _sample_external_nuts(
    659         sampler=nuts_sampler,
    660         draws=draws,
    661         tune=tune,
    662         chains=chains,
    663         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    664         random_seed=random_seed,
    665         initvals=initvals,
    666         model=model,
    667         progressbar=progressbar,
    668         idata_kwargs=idata_kwargs,
    669         nuts_sampler_kwargs=nuts_sampler_kwargs,
    670         **kwargs,
    671     )
    673 if isinstance(step, list):
    674     step = CompoundStep(step)

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:313, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    310 elif sampler == "numpyro":
    311     import pymc.sampling.jax as pymc_jax
--> 313     idata = pymc_jax.sample_numpyro_nuts(
    314         draws=draws,
    315         tune=tune,
    316         chains=chains,
    317         target_accept=target_accept,
    318         random_seed=random_seed,
    319         initvals=initvals,
    320         model=model,
    321         progressbar=progressbar,
    322         idata_kwargs=idata_kwargs,
    323         **nuts_sampler_kwargs,
    324     )
    325     return idata
    327 elif sampler == "blackjax":

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:694, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
    692 print("Transforming variables...", file=sys.stdout)
    693 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 694 result = _postprocess_samples(
    695     jax_fn,
    696     raw_mcmc_samples,
    697     postprocessing_backend=postprocessing_backend,
    698     postprocessing_vectorize=postprocessing_vectorize,
    699 )
    700 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
    702 tic4 = datetime.now()

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize)
    190     jax_vfn = jax.vmap(jax_fn)
    191     _, outs = scan(
    192         lambda _, x: ((), jax_vfn(*x)),
    193         (),
    194         _device_put(t_raw_mcmc_samples, postprocessing_backend),
    195     )
--> 196     return [jnp.swapaxes(t, 0, 1) for t in outs]
    197 elif postprocessing_vectorize == "vmap":
    198     return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in <listcomp>(.0)
    190     jax_vfn = jax.vmap(jax_fn)
    191     _, outs = scan(
    192         lambda _, x: ((), jax_vfn(*x)),
    193         (),
    194         _device_put(t_raw_mcmc_samples, postprocessing_backend),
    195     )
--> 196     return [jnp.swapaxes(t, 0, 1) for t in outs]
    197 elif postprocessing_vectorize == "vmap":
    198     return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

    [... skipping hidden 10 frame]

File ~\AppData\Local\miniconda3\Lib\site-packages\jax\_src\interpreters\pxla.py:1149, in ExecuteReplicated.__call__(self, *args)
   1147   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1148 else:
-> 1149   results = self.xla_executable.execute_sharded(input_bufs)
   1150 if dispatch.needs_check_special():
   1151   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 26665600000 bytes.

You might have very large deterministics. Can you conceivably fit them all in RAM at once?

1 Like

Thank you very much Ricardo! Deterministic variables indeed caused it. I was following the example of Discrete Choice and Random Utility Models and set p_ as deterministic variables. After replacing it with just the softmax function, the problem was solved!

1 Like