I have a question about the usage of pm.sampling.jax.sample_blackjax_nuts. I have a training dataset with roughly 11K rows, and I built a model that leverages the hierarchical structure within it. When I use pm.sample(return_inferencedata=True), it takes around 4-6 hours to train the model. Since the server I'm using has some GPUs, I decided to use them. However, I keep running into out-of-memory errors:
2023-03-07 02:59:02.202448: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2163] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.
I've tried the following, without success:
- chain_method='vectorized' instead of the default ('parallel').
- %env XLA_PYTHON_CLIENT_PREALLOCATE=false (this ends up returning 4 chains with 4 unique values).
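For reference, here is a minimal sketch of how I set things up; the model itself is omitted, and the memory-related environment variables (including XLA_PYTHON_CLIENT_MEM_FRACTION, which I have not mentioned above) are set before JAX is imported, since they are only read at import time:

```python
import os

# Disable JAX's default behavior of preallocating most of the GPU memory.
# These must be set BEFORE jax / pymc.sampling.jax is first imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap the fraction of GPU memory JAX may claim:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.80"

# Hypothetical usage (my actual hierarchical model is not shown here):
# import pymc as pm
# with pm.Model() as model:
#     ...  # hierarchical model on the ~11K-row dataset
#     idata = pm.sampling.jax.sample_blackjax_nuts(
#         chains=4,
#         chain_method="vectorized",  # batch all chains on one GPU
#     )
```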
I've also tried the different solutions suggested in this other post, with no luck.
Are there any best practices or guides on how to use GPUs with PyMC and avoid these out-of-memory errors?
Any help is appreciated.