Out of Memory when using pm.sampling.jax.sample_blackjax_nuts


I have a question about using pm.sampling.jax.sample_blackjax_nuts. My training dataset has roughly 11K rows, and I built a model that leverages its hierarchical structure. With pm.sample(return_inferencedata=True), fitting the model takes around 4-6 hours. Since the server I’m using has some GPUs, I decided to use them. However, I keep running into out-of-memory errors:

2023-03-07 02:59:02.202448: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2163] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call ‘xla.gpu.custom_call’ failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.

I’ve tried the following without success:

  • using chain_method='vectorized' instead of the parallel option,
  • setting %env XLA_PYTHON_CLIENT_MEM_FRACTION=.50,
  • setting %env XLA_PYTHON_CLIENT_PREALLOCATE=false (this ends up returning 4 chains with 4 unique values).
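For reference, the two environment variables in the list above can also be set from Python instead of the %env magic. The sketch below is only an illustration: configure_xla_memory is a hypothetical helper name, not part of PyMC or JAX. It assumes the variables are set before JAX initializes its GPU backend, which is easiest to guarantee by setting them before anything JAX-based is imported. As far as the JAX documentation describes it, the memory fraction only controls how much is preallocated while preallocation is enabled.

```python
import os


def configure_xla_memory(preallocate=False, mem_fraction=None):
    """Hypothetical helper: set JAX/XLA GPU memory options via environment variables.

    Call this before JAX (or any JAX-based sampler module) is first imported.
    """
    # Turn the default up-front GPU memory preallocation on or off.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory to preallocate when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


# Equivalent to: %env XLA_PYTHON_CLIENT_PREALLOCATE=false
configure_xla_memory(preallocate=False)

# Import JAX-based samplers only after the variables are set, for example:
# import pymc.sampling.jax
```

In a notebook, this cell has to run before the first cell that imports JAX; if JAX has already been initialized, restarting the kernel is the safe option.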

I’ve also tried the different solutions provided in this other post with no luck.

Are there any best practices / guides on how to use GPUs with PyMC and avoid the out of memory error?

Any help is appreciated.

I ended up switching to pm.sampling.jax.sample_numpyro_nuts and setting %env XLA_PYTHON_CLIENT_PREALLOCATE=false (check here). That solved the issue. I still don’t know why I couldn’t get the Blackjax implementation to work, though.
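A minimal sketch of that workaround, for anyone landing here. The module path pymc.sampling.jax follows the calls used in this thread and may differ in other PyMC versions; sample_with_numpyro is a hypothetical wrapper name, and the draws, tune, and chains values are placeholders rather than settings from the original model.

```python
import os

# Set before any JAX import so the GPU backend sees it when it initializes.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def sample_with_numpyro(model, draws=1000, tune=1000, chains=4):
    """Hypothetical wrapper: sample a PyMC model with the NumPyro NUTS sampler."""
    # Imported lazily so the environment variable above is already in place
    # by the time JAX is loaded.
    import pymc.sampling.jax as pmjax

    return pmjax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        chains=chains,
        model=model,
    )


# Usage, assuming `model` is an existing pm.Model:
# idata = sample_with_numpyro(model)
```

Whether this avoids the out-of-memory error will depend on the model size and the GPU; it only mirrors the combination that worked in this case.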


I have similar problems with sample_blackjax_nuts on the GPU, and they don’t reproduce with the NumPyro sampler. I’ve noticed that the Blackjax sampler also has issues on CPUs – not failures, but perhaps something related to poor initialization.