Pathfinder Variational Inference Not working on GPU (WSL)

Hi there,

I tried to rerun this Pathfinder Variational Inference example, but it always runs on the CPU:

pmx.fit(num_paths=16, num_draws_per_path=1000, num_draws=2000, method="pathfinder", jitter=12,
        postprocessing_backend='gpu', inference_backend="pymc", maxiter=10000)

And this one doesn’t work:

pmx.fit(num_paths=16, num_draws_per_path=1000, num_draws=2000, method="pathfinder", jitter=12,
        postprocessing_backend='gpu', inference_backend="blackjax", maxiter=10000)

The error I received:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/xxx/lib/python3.10/site-packages/pymc_extras/inference/fit.py", line 35, in fit
    return fit_pathfinder(**kwargs)
  File "/home/xxx/lib/python3.10/site-packages/pymc_extras/inference/pathfinder/pathfinder.py", line 1733, in fit_pathfinder
    pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
  File "/home/xxx/lib/python3.10/site-packages/blackjax/vi/pathfinder.py", line 180, in approximate
    elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
  File "/home/xxx/lib/python3.10/site-packages/blackjax/vi/pathfinder.py", line 159, in path_finder_body_fn
    phi, logq = bfgs_sample(
  File "/home/xxx/lib/python3.10/site-packages/blackjax/optimizers/lbfgs.py", line 353, in bfgs_sample
    Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_interface.cc:125: operation cusolverDnDorgqr(handle, m, n, k, a, m, tau, workspace, lwork, info) failed: cuSolver internal error
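For what it's worth, the operation that crashes can be reproduced in isolation. The snippet below is a NumPy stand-in for the failing `jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)` line (the shapes here are made up for illustration); it runs fine on CPU, which suggests the problem is in the cuSOLVER code path rather than the math itself:

```python
import numpy as np

# NumPy equivalent of the blackjax line that fails on the GPU:
#   Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)
# alpha and beta shapes are invented here for illustration only.
rng = np.random.default_rng(0)
alpha = rng.uniform(0.5, 2.0, size=8)   # positive diagonal scaling
beta = rng.normal(size=(8, 4))

scaled = np.diag(np.sqrt(1.0 / alpha)) @ beta
Q, R = np.linalg.qr(scaled)

print(Q.shape, R.shape)           # (8, 4) (4, 4)
print(np.allclose(Q @ R, scaled))  # True
```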

It is not a preallocation issue like others have demonstrated on the web; I have the following settings:

os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
jax.config.update("jax_enable_x64", True)
os.environ["JAX_TRACEBACK_FILTERING"]="off"
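One thing that may be worth double-checking: XLA reads these environment variables once, when the JAX backend first initializes, so they only take effect if they are set before the first `import jax` anywhere in the process. A minimal sketch of the safe ordering:

```python
import os

# Set the XLA/JAX environment variables first; they are read once,
# when the JAX backend initializes, and ignored afterwards.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

# Only now import jax; jax.devices() should then list CudaDevice
# entries if the GPU backend came up correctly:
# import jax
# print(jax.devices())
```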

And these are my CUDA settings:

| NVIDIA-SMI 572.16                 Driver Version: 572.16         CUDA Version: 12.8   |

I feel it might be some issue with blackjax. Any help is appreciated!! Thank you in advance!!

2 Likes

Hey Dennis,

Sorry about this – we recently updated the Pathfinder implementation, which no longer uses the Blackjax library by default. Once we cut a new release of pymc-examples (which will hopefully happen in the next few days), the notebook will also be updated. If you want to use the Blackjax implementation, adding an inference_backend="blackjax" argument to pmx.fit should work. The PyMC implementation does not yet use JAX, but it has other nice features that the Blackjax version does not, such as multi-path Pathfinder.

2 Likes

Thank you for your reply!!

The main objective for me is to leverage GPU/JAX/parallel computing to accelerate my code. I have a huge dataset and over 10 million parameters to estimate.

pmx.fit(num_paths=16, num_draws_per_path=1000, num_draws=2000, method="pathfinder", jitter=12,
        postprocessing_backend='gpu', inference_backend="blackjax", maxiter=10000)

This one does not work for me either… Do you know if the variational inference algorithms in PyMC can leverage the GPU or parallel computing? It seems that neither ADVI nor Pathfinder supports the GPU…

1 Like

Hi @Dennis_Ma,

Best to use the pymc backend for Pathfinder for now; there are still a few developments underway for blackjax. The pymc backend for Pathfinder is fully functional, with Multi-Path Pathfinder available!

At the moment, most of the argument values you passed in the blackjax fit above, like num_paths and num_draws_per_path, are ignored.

It could be related to the CUDA solver hitting memory limits with your 10M parameters. It is possible that, after improvements to the blackjax backend, some of the compute operations will be optimised a bit better and become more memory efficient. It's hard to know at this stage whether you'd still encounter the CUDA memory limits after these changes.

It's difficult to decompose the ADVI algorithm into parallel computations because most of the update and evaluation procedures happen sequentially. As for leveraging the GPU, I'd say no (I think), because getting JAX + JIT + ADVI to work together has been a bit tough to figure out.
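To illustrate the sequential structure (this is a toy gradient-ascent loop, not PyMC's actual ADVI), each step depends on the result of the previous one, which is why the outer loop itself cannot be parallelized:

```python
# Toy illustration (not PyMC's ADVI): a stochastic-optimization-style
# update loop where step t needs the result of step t - 1.
mu = 0.0
learning_rate = 0.1
for _ in range(100):
    grad = -(mu - 3.0)              # gradient of a toy objective peaking at mu = 3
    mu = mu + learning_rate * grad  # depends on the previous value of mu

print(round(mu, 3))  # converges toward 3.0
```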

The Pathfinder algorithm is "embarrassingly parallel", and the pymc backend supports parallel computing on a single machine with concurrent="thread" and concurrent="process". The default setting is concurrent=None. Though, from limited tests, I didn't see much speed-up, possibly because of the additional overhead of concurrency and because the models I tried it on were small.
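The "embarrassingly parallel" structure can be sketched with the standard library; `run_single_path` below is a hypothetical stand-in for one Pathfinder optimization path, not the actual pymc internals:

```python
from concurrent.futures import ThreadPoolExecutor

def run_single_path(seed):
    """Hypothetical stand-in for one independent Pathfinder path."""
    # Each path needs only its own seed, so paths share no state
    # and can run concurrently.
    return seed * seed  # pretend per-path result

# Roughly the shape of what concurrent="thread" does: run the paths
# independently and collect the per-path results afterwards.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(run_single_path, range(16)))

print(results[:4])  # [0, 1, 4, 9]
```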

A new pymc-examples notebook for Pathfinder will be updated soon (possibly next week).

3 Likes

Thank you so much for the detailed explanation!! Looking forward to the new Pathfinder example notebook!

1 Like