Accessing GPU on a Windows 11 PC

Has anyone been able to use the GPU on a Windows 11 machine for sampling in PyMC, or does anyone know whether it is possible? With JAX I can only access the CPU. I have GPU access from PyTorch and TensorFlow, but I don't know whether I can use these in the PyMC sampling methods. Any advice or pointers to documentation would be appreciated.

I don't think there's anything for your case, although we are working on adding a PyTorch backend to PyTensor that would allow exploiting the GPU. That's still in the works, though: see "Add initial support for PyTorch backend" by HarshvirSandhu (Pull Request #764 in pymc-devs/pytensor on GitHub).

You can do GPU sampling via WSL2, but there's some song and dance involved. See here and here for instructions.
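Once JAX can see the GPU inside WSL2, the sampling side is mostly a matter of asking PyMC for a JAX-based NUTS sampler. Below is a minimal sketch, assuming PyMC, JAX (with CUDA support) and NumPyro are installed in the WSL2 environment. `choose_nuts_sampler` and `run_example` are hypothetical helper names used only for illustration; `nuts_sampler="numpyro"` is the `pm.sample` option that hands the compiled model to NumPyro's NUTS, which then runs on whatever device JAX selected.

```python
def choose_nuts_sampler(platforms):
    """Hypothetical helper: prefer the JAX/NumPyro sampler when JAX reports a GPU,
    otherwise fall back to PyMC's default NUTS implementation."""
    return "numpyro" if "gpu" in set(platforms) else "pymc"


def run_example(draws=1000, seed=0):
    """Hypothetical demo: fit a toy normal model with the sampler chosen above."""
    # Imported lazily so this file can be loaded without PyMC/JAX installed.
    import jax
    import numpy as np
    import pymc as pm

    # Each JAX device reports a platform string such as "cpu" or "gpu".
    platforms = {d.platform for d in jax.devices()}
    sampler = choose_nuts_sampler(platforms)

    rng = np.random.default_rng(seed)
    y = rng.normal(loc=1.0, scale=2.0, size=500)  # toy data, not from the thread

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)
        # With nuts_sampler="numpyro" the model is compiled to JAX and sampled by
        # NumPyro, so it runs on the GPU if JAX picked the GPU as its device.
        idata = pm.sample(draws=draws, nuts_sampler=sampler, random_seed=seed)
    return sampler, idata
```

Calling `run_example()` inside WSL2 returns the sampler name that was used alongside the `InferenceData`; if it comes back as `"pymc"`, JAX did not see the GPU and the CUDA setup is the thing to revisit.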

For what it’s worth, I’ve never had good experiences with GPU sampling via JAX.

It depends on the model and on whether JAX is actually finding the GPU, but we have seen many models that can only be sampled reasonably fast on the GPU.
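A quick way to tell whether JAX is actually finding the GPU is to list the devices it sees. This is a small sketch using `jax.devices()` and `jax.default_backend()`; the function names `jax_platforms` and `jax_sees_gpu` are hypothetical, and the check is written to return an empty result rather than fail when JAX is not installed.

```python
def jax_platforms():
    """Return the sorted platform names of the devices JAX can see,
    e.g. ['cpu'] or ['gpu']; an empty list if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return []
    # Each Device object has a .platform attribute such as "cpu" or "gpu".
    return sorted({d.platform for d in jax.devices()})


def jax_sees_gpu():
    """True when at least one visible JAX device is a GPU."""
    return "gpu" in jax_platforms()


if __name__ == "__main__":
    platforms = jax_platforms()
    if not platforms:
        print("JAX is not installed")
    else:
        import jax

        # default_backend() names the backend JAX will use for computations.
        print("devices:", platforms, "| default backend:", jax.default_backend())
```

On native Windows this will typically report only `cpu`; inside a correctly configured WSL2 environment it should report `gpu`.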

Thank you for the reply. I am new to Discourse and just learning how it works. Is there a way for me to subscribe to this subject so I can be notified if this feature becomes available?

Thank you. I reviewed a lot of materials but did not come across this one. I was not sure what WSL meant in 'Windows WSL2 x86_64' in the JAX installation instructions; the Microsoft link you provided clears that up. I will work through these materials and see if I can get it working on my machine. I bought this laptop specifically for its NVIDIA GeForce RTX 4060 Laptop GPU to use with PyMC, so it's a bit of a bummer that I might not be able to make good use of it.

You can follow specific PRs on GitHub if you have an account. Otherwise I don't know. We will probably make a couple of announcements on social media when it's ready.