Hi there,
I tried to rerun this Pathfinder variational inference, but with the following call it always runs on the CPU:
pmx.fit(
    num_paths=16,
    num_draws_per_path=1000,
    num_draws=2000,
    method="pathfinder",
    jitter=12,
    postprocessing_backend="gpu",
    inference_backend="pymc",
    maxiter=10000,
)
And this one doesn’t work:
pmx.fit(
    num_paths=16,
    num_draws_per_path=1000,
    num_draws=2000,
    method="pathfinder",
    jitter=12,
    postprocessing_backend="gpu",
    inference_backend="blackjax",
    # maxiter=10000,
)
The error I received:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/xxx/lib/python3.10/site-packages/pymc_extras/inference/fit.py", line 35, in fit
    return fit_pathfinder(**kwargs)
  File "/home/xxx/lib/python3.10/site-packages/pymc_extras/inference/pathfinder/pathfinder.py", line 1733, in fit_pathfinder
    pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
  File "/home/xxxlib/python3.10/site-packages/blackjax/vi/pathfinder.py", line 180, in approximate
    elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
  File "/home/xxx/lib/python3.10/site-packages/blackjax/vi/pathfinder.py", line 159, in path_finder_body_fn
    phi, logq = bfgs_sample(
  File "/home/xxx/python3.10/site-packages/blackjax/optimizers/lbfgs.py", line 353, in bfgs_sample
    Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_interface.cc:125: operation cusolverDnDorgqr(handle, m, n, k, a, m, tau, workspace, lwork, info) failed: cuSolver internal error
It does not seem to be a preallocation issue, as others have suggested on the web. I have the following settings:
os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
jax.config.update("jax_enable_x64", True)
os.environ["JAX_TRACEBACK_FILTERING"]="off"
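In case the ordering matters, here is a minimal sketch of applying the same settings before JAX is first imported. As far as I understand, JAX reads most of these environment variables at import time or when the backend initializes, so setting them afterwards may have no effect (that is an assumption on my side, not something I have confirmed for every variable). The helper names apply_jax_env and describe_backend are hypothetical, made up for this sketch.

```python
import os

# Same settings as listed above, gathered in one place.
JAX_ENV = {
    "JAX_PLATFORM_NAME": "gpu",
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "JAX_TRACEBACK_FILTERING": "off",
}


def apply_jax_env(settings=JAX_ENV, environ=os.environ):
    """Write the settings into environ; return the names whose value changed."""
    changed = [name for name, value in settings.items() if environ.get(name) != value]
    environ.update(settings)
    return changed


def describe_backend():
    """Import JAX only now (after the environment is set) and report what it sees."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        jax.config.update("jax_enable_x64", True)
        return f"backend={jax.default_backend()}, devices={jax.devices()}"
    except RuntimeError as err:
        # Raised, for example, when the requested platform cannot be initialized.
        return f"backend initialization failed: {err}"


if __name__ == "__main__":
    print("changed:", apply_jax_env())
    print(describe_backend())
```

If the reported backend is cpu even with these settings applied first, the installed jaxlib build may not have CUDA support, which would be a separate problem from the cuSolver error.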
And this is my CUDA setup:
NVIDIA-SMI 572.16    Driver Version: 572.16    CUDA Version: 12.8
I feel it might be some issue with blackjax. Any help is appreciated!! Thank you in advance!!