Hi all, I am new to PyMC and am used to traditional sampling frameworks such as JAGS in R. I am a statistician working on spatiotemporal models. I am looking forward to using PyMC + JAX + GPU because I am dealing with highly complex models and large spatiotemporal datasets. I have checked the PyMC3 3.11.5 documentation, which is very good, but I couldn't find a clear A-to-Z workflow for people like me who are not very familiar with JAX or GPU programming. Is there a good vignette, tutorial, or other documentation on how to fit a model using PyMC + JAX + GPU/TPU, including the necessary checks for GPU availability and the technical issues around its setup?
Welcome!
The ability to select from a set of backends (of which JAX is one) is a feature that many are eagerly anticipating. However, this feature is available in PyMC v4, which is currently in beta (though we highly recommend installing it, particularly if you are new to PyMC). Given that, and the fact that all of the documentation is currently being redone (the current state of it can be found here), there are no complete, up-to-date PyMC v4 "how to" guides yet (though there will be as we push toward a final v4 release). That being said, I would suggest looking over the information and documents linked to in this post.
In the meantime, please post whatever questions you might have along the way. There are several people here who are using v4 and JAX in production and can help you iron out whatever issues you might encounter.
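As a rough starting point for the GPU check and the sampling step asked about above, here is a minimal sketch, not an official workflow. It assumes `jax`, `numpyro`, and the PyMC v4 beta are installed, and that the JAX samplers live in `pymc.sampling_jax` (the module path may change between releases). The function names `available_jax_devices` and `fit_toy_model`, the toy model, and the synthetic data are hypothetical and only for illustration.

```python
def available_jax_devices():
    """Return (backend_name, device_descriptions); (None, []) if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None, []
    # jax.default_backend() reports "cpu", "gpu", or "tpu";
    # jax.devices() lists the devices JAX can see on that backend.
    return jax.default_backend(), [str(d) for d in jax.devices()]


def fit_toy_model(draws=1000, tune=1000, seed=0):
    """Fit a toy normal model with the NumPyro NUTS sampler, which runs on JAX.

    Assumption: PyMC v4 exposes the sampler as pymc.sampling_jax.sample_numpyro_nuts.
    """
    import numpy as np
    import pymc as pm
    import pymc.sampling_jax  # requires jax and numpyro

    rng = np.random.default_rng(seed)
    y = rng.normal(loc=1.0, scale=2.0, size=200)  # synthetic data, illustration only

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=y)
        # Sampling happens on whichever device JAX selected (GPU if one is visible).
        return pm.sampling_jax.sample_numpyro_nuts(
            draws=draws, tune=tune, chains=2, random_seed=seed
        )


backend, devices = available_jax_devices()
print("JAX backend:", backend)
print("Devices:", devices)
```

If the backend prints as `cpu` on a machine that has a GPU, the usual suspect is a CPU-only `jaxlib` install rather than anything in PyMC. Once the check shows the GPU, calling `fit_toy_model()` should return the posterior draws.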