Has anyone tried using TPU acceleration with the new JAX/XLA numpyro_nuts sampler (Using JAX for faster sampling — PyMC3 3.10.0 documentation)? These Google Edge TPUs from coral.ai seem really affordable: Products | Coral. I am tempted to just give it a try. Do you think it might work?
It’s a pretty long thread, but there is some discussion of using it on TPUs near the end: Inefficient use of pmap in jax numpyro sampler? · Issue #4288 · pymc-devs/pymc3 · GitHub
TL;DR: it works.
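If you want to check what JAX actually sees on your machine before sampling, something like the following should do it. This is only a rough sketch: `describe_jax_backend` is a helper name I made up, and it only reports the backend and devices, it does not run the sampler or tell you whether a particular accelerator is supported.

```python
def describe_jax_backend():
    """Return (backend_name, device_descriptions).

    Returns (None, []) if JAX is not installed.
    The function name is hypothetical, used here for illustration only.
    """
    try:
        import jax
    except ImportError:
        return None, []
    # Name of the platform JAX will use by default, e.g. "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    # One entry per device visible to that backend.
    devices = [str(d) for d in jax.devices()]
    return backend, devices


backend, devices = describe_jax_backend()
print("JAX backend:", backend)
print("Visible devices:", devices)
```

If the backend printed is "cpu", the sampler will still run, just without any accelerator.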