Pathfinder Variational Inference Not working on GPU (WSL)

Thank you for your reply!!

My main objective is to leverage GPU/JAX/parallel computing to accelerate my code: I have a huge dataset and over 10 million parameters to estimate.

```python
pmx.fit(
    method="pathfinder",
    num_paths=16,
    num_draws_per_path=1000,
    num_draws=2000,
    jitter=12,
    maxiter=10000,
    postprocessing_backend="gpu",
    inference_backend="blackjax",
)
```

This one does not work for me either… Do you know whether any of the variational inference algorithms in PyMC can leverage the GPU or parallel computing? It seems that neither ADVI nor Pathfinder supports GPU…
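For what it's worth, before digging into the PyMC side it may be worth confirming that JAX can actually see the GPU under WSL at all (this assumes a CUDA-enabled jaxlib is installed; with a CPU-only jaxlib, JAX silently falls back to CPU and every backend setting is ignored):

```python
import jax

# Devices JAX can see: with a working CUDA install this should list
# at least one CUDA device; with CPU-only jaxlib it lists only CpuDevice.
print(jax.devices())

# "gpu" if the CUDA backend loaded successfully, otherwise "cpu".
print(jax.default_backend())
```

If `jax.default_backend()` prints `cpu`, the problem is the JAX/CUDA setup in WSL rather than Pathfinder itself.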
