JAX GPU access for us Mac users on the new M1/M2 silicon has been non-existent. This might change with Apple's recent release of jax-metal.
I have begun trying to get it to work, but wanted to post this here in case others want to try it as well.
If I get it working, I will post my solution here.
I got the jax-metal library to see my M1 GPU.
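One way to confirm that JAX actually picks up the GPU is to list the devices of its default backend. This is a minimal sketch that assumes `jax` and `jax-metal` are installed. The helper names are hypothetical, and the check that Metal devices report the platform name `METAL` is an assumption, so it is done case-insensitively in case the casing differs between versions.

```python
import jax


def list_devices():
    """Return (platform, device description) pairs for JAX's default backend."""
    return [(d.platform, str(d)) for d in jax.devices()]


def metal_visible():
    """True if any default-backend device reports a Metal platform.

    Assumption: jax-metal devices report a platform name of "METAL";
    the comparison is case-insensitive in case the casing varies by version.
    """
    return any(platform.lower() == "metal" for platform, _ in list_devices())


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    for platform, name in list_devices():
        print(f"{platform}: {name}")
    print("Metal GPU visible:", metal_visible())
```

On a machine without jax-metal, this should simply report the CPU device and `Metal GPU visible: False`.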
I almost got it to run through Bambi, but I'm getting an XLA runtime error on compilation. I have no idea what that error means. I used the following at the start of the notebook:
import numpyro
numpyro.set_platform(platform="METAL")
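Because the error appears at compilation, it may help to check whether XLA can compile anything at all on the Metal device, independently of Bambi and NumPyro. The sketch below jit-compiles a trivial function and commits its input to a chosen device with `jax.device_put`, so the computation runs there. It falls back to the CPU device when no Metal device is found. The function names are hypothetical, and the case-insensitive `"metal"` platform check is an assumption about how jax-metal names its devices.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(x):
    # Trivial computation; calling it triggers an XLA compile for x's device.
    return jnp.sum(x ** 2)


def pick_device():
    """Prefer a Metal device from the default backend, else the first CPU device.

    Assumption: jax-metal devices report a platform name of "METAL"
    (compared case-insensitively).
    """
    metal = [d for d in jax.devices() if d.platform.lower() == "metal"]
    return metal[0] if metal else jax.devices("cpu")[0]


def smoke_test(device):
    # Committing the input to `device` makes the jitted computation run there.
    x = jax.device_put(jnp.arange(4.0), device)
    return float(sum_of_squares(x))


if __name__ == "__main__":
    device = pick_device()
    print("running on:", device)
    print("result:", smoke_test(device))  # 0 + 1 + 4 + 9 = 14.0
```

If this tiny example already raises an XLA runtime error on the Metal device, that would suggest the problem lies in the jax-metal setup rather than in Bambi or NumPyro.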