The PyMC team would like to announce that we are forking the Aesara project.
This new project is called PyTensor. PyTensor will allow for new features such as labeled arrays, as well as speed up development and streamline the PyMC codebase and user experience. PyTensor is a community-focused project and contributions are welcome.
For further details, please see the full announcement here.
@baggiponte We intend to increase the range of PyTensor graphs (and hence PyMC models) that can be successfully converted to JAX.
That was already a goal of the Aesara team. The main backend for now is still C and will probably become Numba in the future, with JAX as the second best-supported backend.
Do you have specific questions regarding the JAX backend?
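(For anyone curious what switching backends looks like in practice, here is a minimal sketch. It assumes the `mode="JAX"` and `mode="NUMBA"` compilation modes that PyTensor inherits from Aesara; the exact spelling may change in future releases, and the corresponding JAX/Numba packages need to be installed.)

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# A small symbolic graph: a (naive) log-sum-exp of a vector.
x = pt.vector("x")
y = pt.log(pt.sum(pt.exp(x)))

# Default compilation (C backend where available).
f_c = pytensor.function([x], y)

# The same graph compiled through the JAX and Numba linkers
# (assumed mode names, carried over from Aesara).
f_jax = pytensor.function([x], y, mode="JAX")
f_numba = pytensor.function([x], y, mode="NUMBA")

data = np.random.default_rng(0).normal(size=1_000)
print(f_c(data), f_jax(data), f_numba(data))
```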
Oh right, I totally forgot about compiling to C. That makes a lot of sense, even just for backwards compatibility. I'm afraid I am not expert enough to have specific questions yet, but thank you for the prompt reply - can't wait to see what comes out!
@ricardoV94 @aseyboldt With regard to Numba, will some of the work from nutpie (i.e. Rust + a more efficient sampling algorithm) make its way into PyTensor?
Yes, the next release of nutpie (which should be coming out in the next couple of days) will use PyTensor instead of Aesara.
We are also working on a lot of improvements to the Numba backend in PyTensor. If things go well, nutpie could at some point become the default PyMC sampler.
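(To make the connection concrete, here is a rough sketch of what sampling a PyMC model with nutpie looks like. It assumes nutpie's `compile_pymc_model` and `sample` entry points; check the nutpie README for the current API and argument names.)

```python
import numpy as np
import pymc as pm
import nutpie

# Toy model: estimate the mean and scale of some noisy observations.
observed = np.random.default_rng(42).normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=5.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=observed)

# Compile the model's logp graph (through PyTensor) and hand it to
# nutpie's Rust NUTS implementation; returns an ArviZ InferenceData.
compiled = nutpie.compile_pymc_model(model)
idata = nutpie.sample(compiled, draws=1000, chains=4)
```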
Very cool, great work on your part. When I have tested nutpie it has been faster than the NumPyro JAX NUTS sampler, and it just edged out the BlackJAX one. Having that as the default would be insane. I was thinking of getting a GPU rig with CUDA, but I might go with an Apple Studio Ultra with 20 cores and 128 GB of RAM if default nutpie is coming down the pike. The delta between that and a single mid-level NVIDIA GPU might not be that bad.