Hi there,
I am using JAX (for my PhD research, not GSoC) to implement a modified variational approximation for generalized linear mixed-effects models (GLMMs), non-linear mixed-effects models, and certain types of state-space models in distributed settings.
The only difference in the implementation of the algorithm across these different model types is the log probability function (and its derivative).
I thought it could be exciting and valuable further down the line to try to write the models in PyMC, take the Aesara graph given by the model, and convert it to a JAX compiled function, thus giving me the required log probability (and the derivative) in JAX.
My first question: have I understood correctly how a PyMC model, the Aesara graph, and JAX interact?
Suppose that we consider only GLMMs from now on. I want to implement a PyMC model for GLMMs in a generic form, so that if I were to do 5 different case studies involving GLMMs, each with a different random-effects structure, I could use the same PyMC model.
In my head, this means something like
\textbf{y}\sim\text{ExponentialFamily}\big(g^{-1}(\boldsymbol{\eta})\big),
\boldsymbol{\eta}=\textbf{X}\boldsymbol{\beta} + \textbf{Z}\textbf{b},
followed by the necessary priors for the unknown parameters.
In this form, each GLMM could be fit using the same PyMC model, just with different \textbf{y}, \textbf{X} and \textbf{Z} as input.
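To make the target concrete, here is a minimal hand-written JAX sketch of the kind of function I would hope to get out of such a model. It is only one instance of the generic form: I assume a Poisson response with a log link, and the priors, the function name `poisson_glmm_logp`, the parameter names, and the toy data are all made up for illustration. It is not how PyMC would build the function.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import norm


def poisson_glmm_logp(params, y, X, Z):
    """Log joint density of a Poisson GLMM with log link (illustrative only)."""
    beta = params["beta"]                # fixed effects
    b = params["b"]                      # random effects
    log_sigma_b = params["log_sigma_b"]  # log of the random-effect std. dev.

    # Linear predictor: eta = X beta + Z b
    eta = X @ beta + Z @ b

    # Poisson log-likelihood with mean exp(eta)
    loglik = jnp.sum(y * eta - jnp.exp(eta) - gammaln(y + 1.0))

    # Priors: assumed here purely for the sake of the example
    lp_beta = jnp.sum(norm.logpdf(beta, 0.0, 10.0))
    lp_b = jnp.sum(norm.logpdf(b, 0.0, jnp.exp(log_sigma_b)))
    lp_sigma = norm.logpdf(log_sigma_b, 0.0, 1.0)

    return loglik + lp_beta + lp_b + lp_sigma


# Log probability and its gradient with respect to `params`, compiled once
logp_and_grad = jax.jit(jax.value_and_grad(poisson_glmm_logp))

# Toy data: 4 observations, intercept + 1 covariate, 2 groups (random intercepts)
y = jnp.array([1.0, 0.0, 3.0, 2.0])
X = jnp.array([[1.0, 0.5], [1.0, -0.3], [1.0, 1.2], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

params = {
    "beta": jnp.zeros(2),
    "b": jnp.zeros(2),
    "log_sigma_b": jnp.array(0.0),
}

value, grads = logp_and_grad(params, y, X, Z)
```

Here `grads` has the same dictionary structure as `params`, so a different random-effects structure only changes \textbf{Z} and the length of `b`, not the function itself.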
My next question is whether there is an example of a generalized linear mixed effects model implemented in this form in PyMC?
I couldn’t find an existing implementation anywhere. Maybe it could be an interesting example notebook I could work on if it does not exist?
P.S. I am not knowledgeable on any topic, so if anything above sounds seriously misguided, please let me know.
Thanks!