Jax-Metal pm.sample with numpyro on M1/M2

Has anyone been able to sample successfully using Jax-Metal on M1/M2 Apple Silicon? So far I've only managed to trade an XLA runtime error for a segmentation fault.

Not really. Some folks tried it here: "ENH: Get sampling working using Apple Silicon GPU via jax backend" (pymc-devs/pymc issue #7332 on GitHub).
