Hello PyMC team,
I am an undergraduate student passionate about probabilistic programming and machine learning. I have been using PyMC and PyTensor in some of my projects and for my own learning, and I really appreciate the work you have done.
I am writing to propose a project for the Google Summer of Code 2023 program. The idea is to add JAX support to the distributions in PyTensor that currently lack it. As you know, JAX is a library for high-performance numerical computing that provides automatic differentiation and GPU/TPU acceleration. By adding JAX support to these PyTensor distributions, we can combine the strengths of both libraries and enable more efficient and flexible inference and learning.
The project would involve the following tasks:
- Reviewing the existing PyTensor codebase and identifying which distributions lack JAX support.
- Implementing JAX versions of those distribution classes and methods, following PyTensor's coding style and conventions.
- Writing unit tests and documentation for the new code.
- Benchmarking the performance and accuracy of the JAX distributions against the existing implementations.
- Integrating the JAX distributions with PyMC’s inference algorithms and samplers.
I have some experience with Python, NumPy, Aesara, and JAX. I have also contributed to PyTensor, Aesara, and other open source projects. You can find my GitHub profile here: sudarsan2k5
I would love to work on this project with your guidance and feedback. Please let me know if you would be interested in mentoring me, or if you have any questions or suggestions.
Thank you for your time and consideration.
Sincerely,
Sudarsan Mansingh