Implementation of generalized linear mixed models in the form X\beta + Zb

Hi there,

I am using JAX (for my PhD research, not GSoC) to implement a modified variational approximation, in distributed settings, for models such as generalized linear mixed-effects models, non-linear mixed-effects models, and certain types of state-space models.
The only difference in the implementation of the algorithm across these different model types is the log probability function (and its derivative).
I thought it could be exciting and valuable further down the line to write the models in PyMC, take the Aesara graph produced by the model, and convert it into a JAX-compiled function, thus giving me the required log probability (and its derivative) in JAX.
My first question: have I interpreted correctly how a PyMC model, the Aesara graph, and JAX interact?
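Concretely, here is the kind of thing I imagine. This is only a rough, untested sketch: I am assuming a PyMC v4-era API where model.logpt() returns the Aesara graph of the joint log probability and Aesara's JAX backend is selected with mode="JAX"; the exact names may differ between versions.

```python
import aesara
import pymc as pm

# Toy model standing in for a GLMM; the point is only the compilation step.
with pm.Model() as model:
    beta = pm.Normal("beta", 0.0, 1.0)
    pm.Normal("y", mu=beta, sigma=1.0, observed=[0.1, -0.3, 0.7])

logp = model.logpt()  # Aesara graph of log p(theta, y)

# Ask Aesara to compile the graph (and its gradient) with the JAX backend.
logp_fn = aesara.function(model.value_vars, logp, mode="JAX")
dlogp_fn = aesara.function(
    model.value_vars, aesara.grad(logp, model.value_vars), mode="JAX"
)
```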

Suppose that we consider only GLMMs from now on. I want to implement a PyMC model for GLMMs in a generic form, meaning that if I were to do 5 different case studies involving GLMMs, each with a different random-effect structure, I could use the same PyMC model.
In my head, this means something like

\mathbf{y}\sim\text{ExponentialFamily}\big(g^{-1}(\boldsymbol{\eta})\big),
\boldsymbol{\eta}=\mathbf{X}\boldsymbol{\beta} + \mathbf{Z}\mathbf{b},
followed by the necessary priors for the unknown parameters.

In this form, each GLMM could be fit using the same PyMC model, but with different \mathbf{y}, \mathbf{X}, and \mathbf{Z} as input.
My next question: is there an example of a generalized linear mixed-effects model implemented in this form in PyMC?
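To make it concrete, here is a rough, untested sketch of what I have in mind (I picked a Bernoulli response with a logit link and placeholder priors as one instance; any exponential-family likelihood would slot in the same way):

```python
import pymc as pm

def make_glmm(y, X, Z):
    """Generic GLMM: only y, X, Z change between case studies."""
    with pm.Model() as model:
        beta = pm.Normal("beta", 0.0, 2.5, shape=X.shape[1])  # fixed effects
        sigma_b = pm.HalfNormal("sigma_b", 1.0)               # random-effect scale
        b = pm.Normal("b", 0.0, sigma_b, shape=Z.shape[1])    # random effects
        eta = pm.math.dot(X, beta) + pm.math.dot(Z, b)        # linear predictor
        pm.Bernoulli("y", logit_p=eta, observed=y)            # g^{-1} = logistic
    return model
```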

I couldn’t find an existing implementation anywhere. Maybe it could be an interesting example notebook I could work on if it does not exist?

P.S. I am far from an expert on any of these topics, so if anything above sounds seriously misguided, please let me know.

Thanks!

Hi!

In Bambi we try to do something like that. Have a look at the following notebooks; they belong to a repo with unfinished ideas and examples.

unfinished-ideas/sparse.ipynb at main · bambinos/unfinished-ideas (github.com)
unfinished-ideas/sparse_2.ipynb at main · bambinos/unfinished-ideas (github.com)

Notice we used theano.sparse.structured_dot for \mathbf{Z}\mathbf{b}. We could have used a regular dot product, but \mathbf{Z} is a very sparse matrix.
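For illustration, roughly like this (a minimal sketch against the Theano-era API; in Aesara the same functionality lives under aesara.sparse):

```python
import numpy as np
import scipy.sparse as sp
import theano
import theano.sparse as ts
import theano.tensor as tt

# Z is mostly zeros (indicator blocks per grouping factor), so we keep it
# sparse and use a structured dot instead of a dense matrix product.
Z = ts.csr_matrix("Z")        # symbolic sparse (CSR) matrix
b = tt.matrix("b")            # dense random-effects column vector
Zb = ts.structured_dot(Z, b)

f = theano.function([Z, b], Zb)
Z_val = sp.csr_matrix(np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]))
print(f(Z_val, np.array([[0.5], [-1.0]])))  # -> Z @ b
```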

The question is whether JAX implements sparse matrix multiplication.


Hi Tomas,

Thanks for your answer; it is interesting!

As happens with all good research questions, my opinion has changed: I don't think I need an interface with JAX after all.

As long as I can input \mathbf{X}, \mathbf{y}, and \mathbf{Z}, and interact with a log-probability function \log p(\boldsymbol{\theta}, \mathbf{y}) and its gradient \nabla_{\boldsymbol{\theta}}\log p(\boldsymbol{\theta}, \mathbf{y}), regardless of backend (e.g., Aesara, Theano, or JAX), it will suit my needs.
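For instance, something like the following would be enough. This is a sketch assuming the PyMC v4-era helpers compile_logp / compile_dlogp; the method names differ between versions.

```python
import pymc as pm

with pm.Model() as model:
    theta = pm.Normal("theta", 0.0, 1.0)
    pm.Normal("y", mu=theta, sigma=1.0, observed=[0.2, -0.1, 0.5])

logp = model.compile_logp()    # point dict -> log p(theta, y)
dlogp = model.compile_dlogp()  # point dict -> gradient of log p wrt theta

point = {"theta": 0.0}
print(logp(point), dlogp(point))
```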

I am unsure whether Bambi can do this, but I will dig deeper and let you know! Thanks again.

If you inspect a Bambi model, you will see that there’s an attribute called ._design. This is a design object that contains design matrices for the response (y), the common effects (X) and the group-specific effects (Z). This happens here.

You could either construct a PyMC model using y, X, and Z, or use the PyMC model constructed by Bambi. You can access it with bmb_model.backend.model.
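For example (a minimal sketch with made-up data against a recent Bambi API; the formula "y ~ x + (1|group)" gives one common effect and one group-specific intercept):

```python
import bambi as bmb
import pandas as pd

# Made-up data just to illustrate access to the internals.
df = pd.DataFrame(
    {
        "y": [1.2, 0.7, -0.3, 0.9],
        "x": [0.1, 0.4, -0.2, 0.8],
        "group": ["a", "a", "b", "b"],
    }
)

bmb_model = bmb.Model("y ~ x + (1|group)", df)
bmb_model.build()  # constructs the underlying PyMC model

design = bmb_model._design            # holds the y, X, and Z design matrices
pymc_model = bmb_model.backend.model  # the PyMC model Bambi built
```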
