Proposal for GSoC 2023: Adding JAX support to PyTensor distributions

Hello PyMC team,

I am an undergraduate student passionate about probabilistic programming and machine learning. I have been using PyMC and PyTensor in some of my projects and studies, and I really appreciate the work that you have done.

I am writing to propose a project for the Google Summer of Code 2023 program. The idea is to add JAX support to the distributions in PyTensor that currently do not have it. As you know, JAX is a library that enables high-performance numerical computing with automatic differentiation and GPU/TPU acceleration. By adding JAX support to PyTensor distributions, we can leverage the benefits of both libraries and enable more efficient and flexible inference and learning.

The project would involve the following tasks:

  • Reviewing the existing codebase of PyTensor and identifying which distributions need JAX support.
  • Implementing JAX versions of the distribution classes and methods, following the coding style and conventions of PyTensor.
  • Writing unit tests and documentation for the new code.
  • Benchmarking the performance and accuracy of the JAX distributions against the original ones.
  • Integrating the JAX distributions with PyMC’s inference algorithms and samplers.

I have some experience with Python, NumPy, Aesara, and JAX. I have also contributed to PyTensor, Aesara, and some other open source projects in the past. You can find my GitHub profile here: sudarsan2k5

I would love to work on this project with your guidance and feedback. Please let me know if you are interested in mentoring me or if you have any questions or suggestions.

Thank you for your time and consideration.

Sincerely,
Sudarsan Mansingh

Hi @sudarsan2k5,

Thanks for showing interest. Note that we already have a list of proposed projects here: GSoC 2023 projects · pymc-devs/pymc Wiki · GitHub

You seem to be suggesting a new project that is not in that list. That’s fine, but it may require more work to assess (and to find mentors), whereas the listed projects were already sketched out somewhat and have mentors assigned to them (of course, nothing is set in stone at this point).

About your specific suggestion, I am not aware of anything obvious that’s missing in terms of compatibility between specific PyMC distributions and the JAX backend (although I wouldn’t be shocked if some of the multivariates are not working yet, in which case we should open some issues in our repo).

So I am not saying a custom project in that direction isn’t needed, just that it may require some research on your part in advance to find out whether there’s actually a need. Otherwise, feel free to explore one of the listed projects.
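
For what it's worth, that kind of research could start from something like the rough sketch below. It tries to draw from a few distributions with the JAX backend and records which ones raise. The helper `jax_compat_report` and the two example distributions are made up for illustration, and it assumes that `pm.draw` forwards `mode="JAX"` to the compiled function, so treat it as a starting point rather than an established workflow. It also only covers random draws; logp graphs would need a similar check.

```python
import importlib.util


def jax_compat_report(dist_builders, draw_fn):
    """Call draw_fn on each built distribution.

    Returns {name: None} on success or {name: "ErrorType: message"} on failure.
    This helper is hypothetical, not part of PyMC or PyTensor.
    """
    report = {}
    for name, build in dist_builders.items():
        try:
            draw_fn(build())
            report[name] = None
        except Exception as exc:  # any failure is worth recording here
            report[name] = f"{type(exc).__name__}: {exc}"
    return report


# Only run the real check when both libraries are installed.
if importlib.util.find_spec("pymc") and importlib.util.find_spec("jax"):
    import pymc as pm

    # Illustrative selection; a real survey would loop over many more.
    builders = {
        "Normal": lambda: pm.Normal.dist(mu=0.0, sigma=1.0),
        "MvNormal": lambda: pm.MvNormal.dist(
            mu=[0.0, 0.0], cov=[[1.0, 0.0], [0.0, 1.0]]
        ),
    }
    # Assumption: extra keyword arguments to pm.draw reach the compile step.
    report = jax_compat_report(builders, lambda rv: pm.draw(rv, mode="JAX"))
    for name, error in report.items():
        print(name, "OK" if error is None else error)
```

Anything that shows up with an error there would be a good candidate for an issue in the repo.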

Apologies if you were in contact with another dev already or I missed your project among the list.
