Is it possible to speed up PyMC sampling?

I will let someone with more JAX/GPU experience (@twiecki @ricardoV94?) weigh in on the implementation details. As we have discussed, the user guide for v4/JAX/GPU is not ready yet, but hopefully we can get you sorted out.