Out of memory during the "Transforming variables" step using NumPyro. Increasing postprocessing_chunks doesn't seem to help. What should I do?

Thank you for your quick response. I had been running PyMC 5.8.2 and have just updated to 5.9.0. I see that in 5.9.0 “postprocessing_chunks” is deprecated, and in its place is “postprocessing_vectorize”, whose default option is “scan”. But I’m still running into an out-of-memory issue. Below is the error message. My laptop has 32 GB of memory. I’d really appreciate your help. Thank you!

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[8], line 193
    185     like_model = pm.Potential("like_model", TranWeight*pm.logp(pm.Categorical.dist(p=p_model),dv_model),dims="obs")
    187     # idata = pm.sample_prior_predictive()
    188     # idata.extend(
    189     #     pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True}, random_seed=1033)
    190     # )
    191     # idata.extend(pm.sample_posterior_predictive(idata))
--> 193     idata = pm.sample(draws=1000,
    194                       chains=8,
    195                       cores=8,
    196                       tune=500,
    197                       nuts_sampler="numpyro",
    198                       idata_kwargs={"log_likelihood": False}, 
    199                       nuts_sampler_kwargs={"postprocessing_chunks":100},
    200                       random_seed=1301)
    202 #nutpie sampler- very slow for some reasons.
    203 # compiled_model = compile_pymc.compile_pymc_model(model)
    204 # idata = nutpie.sample(compiled_model,    
   (...)
    210 #                       save_warmup = False
    211 #                      )

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:658, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    654     if not isinstance(step, NUTS):
    655         raise ValueError(
    656             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    657         )
--> 658     return _sample_external_nuts(
    659         sampler=nuts_sampler,
    660         draws=draws,
    661         tune=tune,
    662         chains=chains,
    663         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    664         random_seed=random_seed,
    665         initvals=initvals,
    666         model=model,
    667         progressbar=progressbar,
    668         idata_kwargs=idata_kwargs,
    669         nuts_sampler_kwargs=nuts_sampler_kwargs,
    670         **kwargs,
    671     )
    673 if isinstance(step, list):
    674     step = CompoundStep(step)

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:313, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    310 elif sampler == "numpyro":
    311     import pymc.sampling.jax as pymc_jax
--> 313     idata = pymc_jax.sample_numpyro_nuts(
    314         draws=draws,
    315         tune=tune,
    316         chains=chains,
    317         target_accept=target_accept,
    318         random_seed=random_seed,
    319         initvals=initvals,
    320         model=model,
    321         progressbar=progressbar,
    322         idata_kwargs=idata_kwargs,
    323         **nuts_sampler_kwargs,
    324     )
    325     return idata
    327 elif sampler == "blackjax":

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:694, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
    692 print("Transforming variables...", file=sys.stdout)
    693 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 694 result = _postprocess_samples(
    695     jax_fn,
    696     raw_mcmc_samples,
    697     postprocessing_backend=postprocessing_backend,
    698     postprocessing_vectorize=postprocessing_vectorize,
    699 )
    700 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
    702 tic4 = datetime.now()

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize)
    190     jax_vfn = jax.vmap(jax_fn)
    191     _, outs = scan(
    192         lambda _, x: ((), jax_vfn(*x)),
    193         (),
    194         _device_put(t_raw_mcmc_samples, postprocessing_backend),
    195     )
--> 196     return [jnp.swapaxes(t, 0, 1) for t in outs]
    197 elif postprocessing_vectorize == "vmap":
    198     return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in <listcomp>(.0)
    190     jax_vfn = jax.vmap(jax_fn)
    191     _, outs = scan(
    192         lambda _, x: ((), jax_vfn(*x)),
    193         (),
    194         _device_put(t_raw_mcmc_samples, postprocessing_backend),
    195     )
--> 196     return [jnp.swapaxes(t, 0, 1) for t in outs]
    197 elif postprocessing_vectorize == "vmap":
    198     return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

    [... skipping hidden 10 frame]

File ~\AppData\Local\miniconda3\Lib\site-packages\jax\_src\interpreters\pxla.py:1149, in ExecuteReplicated.__call__(self, *args)
   1147   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1148 else:
-> 1149   results = self.xla_executable.execute_sharded(input_bufs)
   1150 if dispatch.needs_check_special():
   1151   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 26665600000 bytes.