This topic is about the Outreachy project proposal titled "Add adaptation to BlackJAX samplers". It is aimed at Outreachy applicants and describes the work expected of applicants during the application process.
As you can see in the project description, the only hard requirement to apply to this project is being experienced with the GitHub workflow: forking repos and using multiple branches to work on multiple features at once, eventually submitting a PR that can be merged into the main branch. In pymc3, aesara, blackjax… there are also some other requirements to be met while working: code style and formatting, testing, collaboration guidelines… These are described in the contributing guide.
In order to be considered for this project, applicants must have submitted two pull requests (or a single large pull request) to any of pymc3, aesara, blackjax (not recommended yet) or pymc-examples. This will help you familiarize yourself with the contributing workflow, which may look straightforward when written down as a list in the contributing guide but always requires some time and practice to get comfortable with. This requirement also lets us see that you meet the required skills for this project in practice, and ensures that you’ll be able to start working once the internship begins.
In addition, you should also work on the assignment below about the suggested skills. This will help you both start familiarizing yourself with some of the libraries involved and build skills that you’ll need for the project. The assignment below focuses mostly on skills like matrix manipulation and random number generation with JAX, and on reading and adapting other people’s code to your needs, but it can also be used to show functional programming skills as well as the ability to write performant JAX code. We believe it is a great opportunity to show your skills and convince us we should select you, so we’d recommend sharing anything you do on that end, even if it’s only partial.
Task overview
The one-sentence summary of the project is “extending BlackJAX and integrating it seamlessly with PyMC3”. Both projects are under (very) active development and are used by many people in both research and industry. We may need some specific feature by a given date and would in that case implement it ourselves, either before the internship starts or during the internship. The internship work won’t be against any clock; we want internships to be as useful for the community as they are a learning experience for selected interns.
The goals of the project are the following:
- NUTS in BlackJAX
- Basic adaptation in BlackJAX (Stan- and PyMC3-like)
- Integrate PyMC3 with BlackJAX
- Use PyMC3 models to both test integration and benchmark sampling (side note: should we also see at some point how many advanced PyMC3 models can’t be sampled with JAX after omnistaging?)
— from here on, items have no particular order —
- Advanced adaptation schemes for NUTS (campfire, covadapt, block mass matrices…)
- ChESS-HMC
- Benchmarks directly in BlackJAX
At the beginning of the internship we’ll define which of items 1-4 work will start with (probably somewhere between 2 and 3), and the internship will be based on finishing up to item 4, plus 1-2 extra items from the unordered list depending on the starting point, the general state of the pymc3 project and the intern’s interests.
Assignment
The assignment will consist in using the approach outlined in covadapt to sample and work with an n-dimensional multivariate normal given some eigenvectors and eigenvalues. The assignment should be completed using JAX.
The code snippet below can be used to generate the eigenvectors Q and the eigenvalues Σ:
import numpy as np

n = 500
Q = np.zeros((n, 2))
Q[:6, 0] = [1, 0, -4, 6, 0, 4]
Q[:6, 1] = [0, 8, 0, 5, -3, 0]
Q = Q / np.sqrt((Q ** 2).sum(0))  # normalize each column
Σ = np.diag([20, 0.01])
How it works
Given some eigenvectors Q (as orthonormal columns) and eigenvalues Σ, we can represent a covariance matrix C = I + QΣQ^T − QQ^T without storing anything more than those few vectors. The resulting matrix has the given eigenvectors and eigenvalues; all other eigenvalues are 1.
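As an illustration (not part of the required solution), the product C·v follows directly from that definition and costs only O(nk) with k eigenvectors, never forming the n × n matrix. A minimal numpy sketch (the assignment itself asks for JAX, where jax.numpy is essentially a drop-in replacement for these operations):

```python
import numpy as np

def cov_matvec(Q, Sigma, v):
    """Compute C @ v with C = I + Q Sigma Q^T - Q Q^T,
    using only Q and Sigma (never the dense n x n matrix C)."""
    Qt_v = Q.T @ v                       # shape (k,)
    return v + Q @ (Sigma @ Qt_v) - Q @ Qt_v

# Quick check against the dense matrix on a small example
n = 6
Q = np.zeros((n, 2))
Q[:, 0] = [1, 0, -4, 6, 0, 4]
Q[:, 1] = [0, 8, 0, 5, -3, 0]
Q = Q / np.sqrt((Q ** 2).sum(0))
Sigma = np.diag([20.0, 0.01])

C = np.eye(n) + Q @ Sigma @ Q.T - Q @ Q.T  # dense reference, only for the check
v = np.arange(n, dtype=float)
assert np.allclose(cov_matvec(Q, Sigma, v), C @ v)
```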
In order to run NUTS or some other HMC we need the matrix-vector products C·v and C^{-1/2}·v, where C^{-1/2} is some factorization of C^{-1}. Thanks to the structure of C we can implement both matrix actions efficiently, without needing to invert the complete covariance matrix, using only Q and Σ.
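A sketch of why this works, assuming the columns of Q are orthonormal: then C = I + Q(Σ − I)Qᵀ, and one valid choice for the factorization is the symmetric C^{-1/2} = I + Q(Σ^{-1/2} − I)Qᵀ, which can also be applied in O(nk) with no dense inverse:

```python
import numpy as np

def inv_sqrt_matvec(Q, Sigma, v):
    """Compute C^{-1/2} @ v, assuming Q has orthonormal columns,
    via C^{-1/2} = I + Q (Sigma^{-1/2} - I) Q^T."""
    d = 1.0 / np.sqrt(np.diag(Sigma))    # diagonal of Sigma^{-1/2}
    Qt_v = Q.T @ v                       # shape (k,)
    return v + Q @ ((d - 1.0) * Qt_v)

# Check on a random orthonormal Q: applying C^{-1/2} twice gives C^{-1},
# so C @ (C^{-1/2} (C^{-1/2} v)) should recover v.
rng = np.random.default_rng(0)
n, k = 8, 2
Q, _ = np.linalg.qr(rng.normal(size=(n, k)))   # orthonormal columns
Sigma = np.diag([20.0, 0.01])
C = np.eye(n) + Q @ Sigma @ Q.T - Q @ Q.T      # dense reference, only for the check
v = rng.normal(size=n)
w = inv_sqrt_matvec(Q, Sigma, inv_sqrt_matvec(Q, Sigma, v))
assert np.allclose(C @ w, v)
```

The identity follows because, with orthonormal Q, (I + Q D Qᵀ)(I + Q E Qᵀ) = I + Q(D + E + DE)Qᵀ, so squaring I + Q(Σ^{-1/2} − I)Qᵀ yields I + Q(Σ^{-1} − I)Qᵀ = C^{-1}.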
An implementation of the operations required to use a covariance matrix C in the form defined above as the mass matrix of a NUTS/HMC sampler is available at covadapt/matrix.py. You can assume others (the remaining eigenvalues) to be 1.
You should define the methods in Eigvals as functions that take Q and Σ as required inputs; some will also take v as input. They should return either a (Q, Σ) tuple or an array with the results. The n × n matrix C should not be used. Moreover, the draw function will ideally take an n_draws argument to allow generating multiple draws at once.
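For instance (a hypothetical sketch, not covadapt's actual API; the function name and signature are illustrative), vectorized draws from N(0, C) can reuse the same low-rank structure through C^{1/2} = I + Q(Σ^{1/2} − I)Qᵀ, again assuming orthonormal columns of Q. numpy is shown; jax.numpy together with jax.random would be the JAX equivalent:

```python
import numpy as np

def draw(Q, Sigma, n_draws, rng):
    """Draw n_draws samples from N(0, C) with C = I + Q Sigma Q^T - Q Q^T,
    assuming orthonormal columns of Q: x = C^{1/2} z with z ~ N(0, I)."""
    n = Q.shape[0]
    d = np.sqrt(np.diag(Sigma))            # diagonal of Sigma^{1/2}
    z = rng.normal(size=(n_draws, n))      # one row per draw
    # Apply the symmetric C^{1/2} = I + Q (Sigma^{1/2} - I) Q^T to each row
    return z + ((z @ Q) * (d - 1.0)) @ Q.T

rng = np.random.default_rng(1)
Q, _ = np.linalg.qr(rng.normal(size=(500, 2)))  # orthonormal columns
Sigma = np.diag([20.0, 0.01])
samples = draw(Q, Sigma, n_draws=1000, rng=rng)
print(samples.shape)   # (1000, 500)
```

Generating all draws in one (n_draws, n) array keeps the work vectorized, which is exactly what the n_draws argument mentioned above is for.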