Thank you for your quick response. I have been running 5.8.2. I just updated it to 5.9.0. I see that in 5.9.0 "postprocessing_chunks" is deprecated and in its place is "postprocessing_vectorize", for which the default option is "scan". But I'm still running into an out-of-memory issue. Below is the error message. My laptop has 32GB of memory. I'd really appreciate your help. Thank you!
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[8], line 193
185 like_model = pm.Potential("like_model", TranWeight*pm.logp(pm.Categorical.dist(p=p_model),dv_model),dims="obs")
187 # idata = pm.sample_prior_predictive()
188 # idata.extend(
189 # pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True}, random_seed=1033)
190 # )
191 # idata.extend(pm.sample_posterior_predictive(idata))
--> 193 idata = pm.sample(draws=1000,
194 chains=8,
195 cores=8,
196 tune=500,
197 nuts_sampler="numpyro",
198 idata_kwargs={"log_likelihood": False},
199 nuts_sampler_kwargs={"postprocessing_chunks":100},
200 random_seed=1301)
202 #nutpie sampler- very slow for some reasons.
203 # compiled_model = compile_pymc.compile_pymc_model(model)
204 # idata = nutpie.sample(compiled_model,
(...)
210 # save_warmup = False
211 # )
File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:658, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
654 if not isinstance(step, NUTS):
655 raise ValueError(
656 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
657 )
--> 658 return _sample_external_nuts(
659 sampler=nuts_sampler,
660 draws=draws,
661 tune=tune,
662 chains=chains,
663 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
664 random_seed=random_seed,
665 initvals=initvals,
666 model=model,
667 progressbar=progressbar,
668 idata_kwargs=idata_kwargs,
669 nuts_sampler_kwargs=nuts_sampler_kwargs,
670 **kwargs,
671 )
673 if isinstance(step, list):
674 step = CompoundStep(step)
File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\mcmc.py:313, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
310 elif sampler == "numpyro":
311 import pymc.sampling.jax as pymc_jax
--> 313 idata = pymc_jax.sample_numpyro_nuts(
314 draws=draws,
315 tune=tune,
316 chains=chains,
317 target_accept=target_accept,
318 random_seed=random_seed,
319 initvals=initvals,
320 model=model,
321 progressbar=progressbar,
322 idata_kwargs=idata_kwargs,
323 **nuts_sampler_kwargs,
324 )
325 return idata
327 elif sampler == "blackjax":
File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:694, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
692 print("Transforming variables...", file=sys.stdout)
693 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 694 result = _postprocess_samples(
695 jax_fn,
696 raw_mcmc_samples,
697 postprocessing_backend=postprocessing_backend,
698 postprocessing_vectorize=postprocessing_vectorize,
699 )
700 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
702 tic4 = datetime.now()
File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize)
190 jax_vfn = jax.vmap(jax_fn)
191 _, outs = scan(
192 lambda _, x: ((), jax_vfn(*x)),
193 (),
194 _device_put(t_raw_mcmc_samples, postprocessing_backend),
195 )
--> 196 return [jnp.swapaxes(t, 0, 1) for t in outs]
197 elif postprocessing_vectorize == "vmap":
198 return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
File ~\AppData\Local\miniconda3\Lib\site-packages\pymc\sampling\jax.py:196, in <listcomp>(.0)
190 jax_vfn = jax.vmap(jax_fn)
191 _, outs = scan(
192 lambda _, x: ((), jax_vfn(*x)),
193 (),
194 _device_put(t_raw_mcmc_samples, postprocessing_backend),
195 )
--> 196 return [jnp.swapaxes(t, 0, 1) for t in outs]
197 elif postprocessing_vectorize == "vmap":
198 return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
[... skipping hidden 10 frame]
File ~\AppData\Local\miniconda3\Lib\site-packages\jax\_src\interpreters\pxla.py:1149, in ExecuteReplicated.__call__(self, *args)
1147 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1148 else:
-> 1149 results = self.xla_executable.execute_sharded(input_bufs)
1150 if dispatch.needs_check_special():
1151 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 26665600000 bytes.