Specifying the GPU

I have more than one GPU in my system and was curious whether there is a way to specify which GPU to use when sampling a PyMC trace. Is this a setting made within JAX, or within NumPyro/BlackJAX?

Someone with experience on multi-GPU systems can chime in, but I think you can select devices using the info found here.
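One common approach, offered here as a sketch rather than a confirmed PyMC recipe, is the CUDA environment variable `CUDA_VISIBLE_DEVICES`. It is read by the CUDA runtime itself, not by PyMC, NumPyro or BlackJAX, so it restricts which GPUs JAX can see no matter which sampler runs on top of it. The helper name `select_gpu` below is hypothetical, and the assumption is that it runs before JAX initializes its GPU backend (safest: before `jax` or PyMC's JAX samplers are imported at all).

```python
import os


def select_gpu(index: int) -> None:
    """Expose only the GPU with the given CUDA index to this process.

    Hypothetical helper. It only sets an environment variable, so it must be
    called before JAX initializes CUDA (safest: before importing jax/pymc).
    Inside the process, the selected card will then appear as device 0.
    """
    if index < 0:
        raise ValueError("GPU index must be non-negative")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
```

For example, call `select_gpu(1)` at the very top of the script, then import PyMC and sample with a JAX-based sampler (in recent PyMC versions, e.g. `pm.sample(nuts_sampler="numpyro")`); `jax.devices()` should then list a single GPU. The same effect can be had from the shell with `CUDA_VISIBLE_DEVICES=1 python my_model.py`, where `my_model.py` is a placeholder for your script. Note that the index follows CUDA's device ordering, which may not match the order shown by other tools unless `CUDA_DEVICE_ORDER=PCI_BUS_ID` is also set.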