Hardware acceleration with Google Edge TPUs

Has anyone tried using TPU acceleration with the new JAX/XLA numpyro_nuts sampler (Using JAX for faster sampling - PyMC3 3.10.0 documentation)? These Google Edge TPUs from coral.ai seem really affordable (Products | Coral), so I am tempted to just give it a try. Do you think it might work?

It’s a pretty long thread, but there is some discussion of using it on TPUs near the end: Inefficient use of pmap in jax numpyro sampler? · Issue #4288 · pymc-devs/pymc3 · GitHub
TL;DR: it works.
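Before trying the sampler on new hardware, it can help to check which devices JAX actually sees. The issue linked above is about how chains are spread across devices with `pmap`. The snippet below is a minimal sketch and assumes a recent JAX install. It prints the active XLA backend and the local devices, then uses `jax.pmap` to run one toy "chain" per device. The `chain_step` function is a hypothetical placeholder workload, not the NUTS sampler. As far as I know, JAX's `tpu` backend targets Cloud TPUs, and a Coral Edge TPU is not exposed to JAX as an XLA device. Running this on the machine in question is a quick way to see what JAX will really use.

```python
import jax
import jax.numpy as jnp

# Name of the XLA backend JAX selected, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# Number of local devices; pmap maps one slice of its input onto each device.
n_devices = jax.local_device_count()

print(f"backend={backend}, devices={n_devices}")
for device in jax.devices():
    print(device)


def chain_step(key):
    # Hypothetical stand-in for one MCMC chain: draw 1000 standard normals
    # and return their mean. This is NOT the numpyro NUTS sampler.
    return jnp.mean(jax.random.normal(key, (1000,)))


# One PRNG key per device, i.e. one toy "chain" per device.
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
chain_means = jax.pmap(chain_step)(keys)
print(chain_means.shape)  # one result per device
```

On a CPU-only machine this reports a single device, so only one "chain" runs at a time under `pmap`. With several TPU cores, each key lands on its own core.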