Hi,
I have a question regarding the usage of `pm.sampling.jax.sample_blackjax_nuts`. I have a training dataset with roughly 11K rows, and I built a model to leverage the hierarchical structure within it. When I use `pm.sample(return_inferencedata=True)`, it takes around 4-6 hours to train the model. Since the server I'm using has some GPUs, I decided to use them. However, I keep running into out-of-memory errors:
```
2023-03-07 02:59:02.202448: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2163] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.
```
I've tried the following without success:
- using `chain_method='vectorized'` instead of the `parallel` option,
- setting `%env XLA_PYTHON_CLIENT_MEM_FRACTION=.50`,
- setting `%env XLA_PYTHON_CLIENT_PREALLOCATE=false` (this ends up returning 4 chains with 4 unique values).
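For reference, here is roughly how I'm combining those settings (the environment variables have to be set before JAX is first imported, since XLA reads them at import time; `model` stands in for my hierarchical model):

```python
import os

# XLA reads these at import time, so set them before `import jax` / `import pymc`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"  # cap JAX at 50% of GPU memory

# Then, inside the model context (sketch only; `model` is the hierarchical
# model described above):
# with model:
#     idata = pm.sampling.jax.sample_blackjax_nuts(
#         chains=4,
#         chain_method="vectorized",  # instead of the default "parallel"
#     )
```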
I've also tried the different solutions provided in this other post, with no luck.
Are there any best practices or guides on how to use GPUs with PyMC and avoid the out-of-memory error?
Any help is appreciated.