JAX on Apple Silicon GPU

JAX GPU access has been non-existent for those of us on Macs with the new M1/M2 Apple silicon. That might change with Apple's recent release of jax-metal.

I have begun trying to get it to work, but wanted to post this here in case other people want to try it as well.

If I get it working, I will post my solution here.


I got the jax-metal library to see my M1 GPU:

[Screenshot 2023-06-13 at 2.22.18 AM]
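
For anyone who wants to run the same check, here is a minimal sketch of how to list the devices JAX can see. It is not necessarily what the screenshot shows. The helper name `describe_devices` is made up for illustration, and the platform string that jax-metal reports may differ between versions.

```python
def describe_devices():
    """Return one 'platform: device_kind' string per device JAX can see.

    Returns an empty list if jax is not installed. describe_devices is a
    hypothetical helper name, not part of jax or jax-metal.
    """
    try:
        import jax
    except ImportError:
        return []
    # jax.devices() lists the devices of the default backend.
    return [f"{d.platform}: {d.device_kind}" for d in jax.devices()]


if __name__ == "__main__":
    for line in describe_devices():
        print(line)
```

If the Metal plugin was picked up, the GPU should show up in that list instead of only a CPU entry.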


I almost got it to run through Bambi, but I am getting an XLA runtime error on compilation, and I have no idea what that error means. I used the following at the start of the notebook:

import numpyro
numpyro.set_platform(platform='METAL')
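
One way to narrow this kind of error down might be to check whether plain JAX can compile anything at all on the default backend, before Bambi or NumPyro get involved. The sketch below is only a diagnostic idea, not a fix. `smoke_test` is a hypothetical helper name, and it catches a broad `Exception` because the exact error class depends on the jax/jaxlib version.

```python
def smoke_test():
    """Try to jit-compile and run a tiny function on JAX's default backend.

    Returns a short status string instead of raising, so it can be run
    in a notebook cell. smoke_test is a hypothetical helper name.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return "jax is not installed"
    try:
        # Compiling even a trivial function goes through XLA.
        double_sum = jax.jit(lambda x: (x * 2.0).sum())
        result = double_sum(jnp.arange(4.0))
        result.block_until_ready()  # wait for the computation to finish
        return f"ok on backend: {jax.default_backend()}"
    except Exception as err:  # error class varies across jax versions
        return f"failed: {type(err).__name__}: {err}"


if __name__ == "__main__":
    print(smoke_test())
```

If this already fails, the problem is probably in the JAX/Metal setup itself. If it passes, the error is more likely triggered by a specific operation used in the model.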