I’m really curious what you find enjoyable about it. Python? Generative model automation for things like posterior predictive checks and simulation-based calibration? Having a model-building API like PyMC rather than a domain-specific language like Stan? I’ve been much happier in Python than in R—plotnine is a life saver.
The reason I ask is not because I think we’ll be able to compete with Stan, but because I am working on building a replacement ecosystem for Stan directly in JAX. Coding a model in JAX is more like coding in Stan than PyMC, because you are just leaning on JAX to give you a differentiable language for specifying a log density. But unlike Stan, it’s designed natively for flexible GPU coding and parallelism and has the constructs to support that (despite their “sharp edges”).
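To make the comparison concrete, here’s a minimal sketch (my own names and a standard eight-schools-style example, not from any package) of what “coding a model in JAX like a Stan model block” looks like: you just write the joint log density as a function and let JAX differentiate it.

```python
# Sketch: a hierarchical log density written directly in JAX, summing
# log densities the way a Stan model block does.
import jax
import jax.numpy as jnp
from jax.scipy import stats

y = jnp.array([28., 8., -3., 7., -1., 1., 18., 12.])
sigma = jnp.array([15., 10., 16., 11., 9., 11., 10., 18.])

def log_density(params):
    mu, log_tau, theta = params["mu"], params["log_tau"], params["theta"]
    tau = jnp.exp(log_tau)                # unconstrain tau, as Stan does internally
    lp = log_tau                          # log-Jacobian of the exp transform
    lp += stats.norm.logpdf(mu, 0., 5.)
    lp += stats.norm.logpdf(tau, 0., 5.)  # half-normal prior, up to a constant
    lp += jnp.sum(stats.norm.logpdf(theta, mu, tau))
    lp += jnp.sum(stats.norm.logpdf(y, theta, sigma))
    return lp

params = {"mu": 0.0, "log_tau": 0.0, "theta": jnp.zeros(8)}
grad = jax.grad(log_density)(params)      # autodiff gives a gradient pytree
```

The parameter pytree plays the role of Stan’s parameters block, and any sampler that consumes a log density and gradient can drive this directly.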
In contrast, PyMC really wants you to build a graphical model. I had originally thought we could just use PyMC or NumPyro to generate JAX, but those are very different paradigms, and the graphical modeling focus makes it impossible to translate Stan models line for line the way that is possible in JAX.
I agree, plotnine is awesome! Knowing I wouldn’t completely leave behind ggplot2 went a long way to convince me to use Python more – and Python is certainly what drew me in to PyMC.
I don’t have formal advanced math training. Having been introduced to Bayesian models through McElreath’s book, and then further indoctrinated by Hobbs and Hooten and Gelman et al.’s BDA (admittedly a much more challenging read for me), I find translating models from the McElreath-style math-stats I write on my chalkboard into Stan a simple workflow. Stan just plain rocks for the models I fit. However, I think it is mostly Python per se and the generative model automation for ppc’s that I’ve enjoyed about PyMC thus far. Additionally, the translation from math-stats to PyMC’s model-building context manager (at least for the simple-ish models) is likewise fairly straightforward, even if I still prefer to work in Stan.
I feel pretty similarly about turing.jl, for what it’s worth!
Modeling in JAX has not been on my radar but sounds like it would be a hit! Especially given the GPU flexibility…
Thanks for the response. What we’ve found with Stan is that the Stan language itself and formulating models is beyond most applications-oriented scientists who just want to use Bayes to fit a model. Part of it is writing loops, part of it’s writing types, and part of it’s just trying to convert stats ideas to code. That’s why tools like brms are an order of magnitude more popular than Stan itself. I think Python attracts a more programming-oriented community than R, so formulating models with the PyMC API may be less of a problem for the PyMC users than it is to formulate models in Stan for the Stan users.
I definitely notice this in my circles as well. Many of my colleagues can fit all of the models they need to in rstanarm which is a fantastic tool. However, I think not learning to translate math to code as in Stan or PyMC obfuscates the power/flexibility of PPL and Bayes, constraining the types of questions applications-oriented scientists think to ask and answer.
From the Hobbs and Hooten book:
“If the only tool in our locker is analysis of variance, then the world we study must be composed of randomized plots.”
I wouldn’t have realized how true that is before visiting epidemiologist John Carlin in Sydney (he’s one of the co-authors of Bayesian Data Analysis). I asked him what kinds of models his students were working on. He just laughed and told me that epidemiologists mostly just learn how to interpret anova output and he had to work with people in the stats department to work on new models (which he does a lot of, by the way).
I came to stats from math and computer science degrees, so coding the models has always been the easiest part of this whole endeavor for me.
The general picture is correct but the comparison is a bit off. PyMC and numpyro are specific PPLs built on top of general python computational packages: PyTensor and JAX, respectively.
These computational packages exist to allow you to use python to create efficient non-python code for array based computations. They also give you autodiff and easy GPU integration.
They are generic so of course it’s easier to implement something custom on top of them. If you wanted to implement Stan in Python you would probably go for something like that and not build on top of an existing PPL.
The relevant question is then whether you want to implement that in JAX or PyTensor (or numba or torch or whatever)[1].
We use Pytensor for historical reasons, since PyMC was built on top of Theano, which predates all of today’s popular libraries.
However, we came to like it on its own terms. It allows us to stay in python land a bit longer to do things like graph manipulation, algebraic rewrites, and numerical optimizations that compilers don’t dare to touch.
Those things in other packages have to be done at a lower level or require one to create a translation layer on top of the package (like manipulating jaxprs).
Because we stay in python land longer, we also found an easy way to target jax/numba and now pytorch without compromising much (the abstraction always leaks eventually, ofc).
OTOH, users can try a couple of backends and pick whichever fits their needs better.
[1] @aseyboldt would probably suggest you use JAX to get model densities and gradients, but write your own sampler in something that gives you more control over CUDA streams.
The other relevant question is what you want to do different than PyMC/numpyro/Stan. And does it require a complete new PPL?
nutpie for example is pushing some normalizing flow-related automatic reparametrizations that would be harder in a Stan-like language where the core object is a sum of densities and not a graphical model.
Adrian’s draft paper hints that he couldn’t get the black-box version working past 1K dimensions but hopes that using model structure somehow will help. Depending on what the input is around model structure, it may or may not be possible with Stan. It’ll certainly be easier to use model structure automatically in PyMC.
I get the part about (embedded) PPLs, but I thought PyMC was more flexible about output format. Thomas Wiecki sent me a notebook producing JAX and some examples under pymc_experimental using Blackjax, but the latter are now 404. Is that just embedding into JAX rather than translating into JAX? We can embed Stan models in JAX that way, but it doesn’t let us jit them or run them on GPUs.
If JAX had existed when we started (2010), we probably would have used that. The big issue is R integration—the stats community is still largely wedded to R. I, on the other hand, have moved to Python.
My point is really that with something like JAX, you don’t need a PPL to write models the way we write them in Stan. Stan’s really just a differentiable programming language with some convenient stats abstractions and a simple syntax; that’s all Stan’s doing anyway, other than giving you a way to structure code and compute posterior predictive quantities in the same language as the model. But with JAX, we can just vmap the output, so we don’t even need the generated quantities block. Similarly, we can just define functions in JAX, so we don’t need the functions block. And pytrees make it very easy to structure arguments. Oryx is now a pretty extensive transform library, and JAX plus many additional packages have a pretty big special function library. One thing I haven’t looked into is all the solvers we use in Stan, like algebraic equation solvers, root finders, 1D integrators, ODE integrators, etc.
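The vmap point can be sketched in a few lines (my own names, with a made-up “posterior” standing in for real draws): a per-draw predictive function replaces a generated quantities block when mapped over stacked posterior draws.

```python
# Sketch: replacing Stan's generated quantities block by vmap-ing a
# per-draw posterior predictive function over posterior draws.
import jax
import jax.numpy as jnp

x_new = jnp.linspace(-2., 2., 5)

def predict(draw, key):
    # one posterior draw -> one posterior predictive replicate
    mu = draw["alpha"] + draw["beta"] * x_new
    return mu + draw["sigma"] * jax.random.normal(key, x_new.shape)

# Pretend posterior: 1000 draws stored as a pytree of stacked arrays.
draws = {"alpha": jnp.ones(1000),
         "beta": jnp.ones(1000),
         "sigma": 0.5 * jnp.ones(1000)}
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
y_rep = jax.vmap(predict)(draws, keys)   # shape (1000, 5)
```

Because the draws are a pytree, vmap maps over the leading (draw) axis of every leaf at once, which is what makes structuring arguments so painless.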
We do this for Stan, too. In using JAX straight up, that’s all on the coder and the JAX developers, which is going to be rough on the statisticians. I’m not sure we ever got a lot out of optimizations in Stan, but then optimizations should be easier if there’s a graphical model base.
I understood your dissatisfaction was about the input, that you can’t write a PyMC model the way you would write a Stan model, line by line. Not about the output.
PyMC is perfectly capable of outputting the model conditional/joint densities to you as jax code (which you called translated into JAX?), which you can then transform/vmap/pmap as you please. That’s how we interop with the jax samplers (or rust via the numba backend).
Regarding the 404: we recently renamed the package from pymc-experimental to pymc-extras, so try changing the link. If you have the links, I can try to dig them out as well.
Sounds like you guys could have used Theano all the way back in 2010 for the reasons you mention JAX today. vmap and pytrees are novel, but wouldn’t have been blockers? I feel R is the harder barrier to cross.
Hi, @ricardoV94. Is that Google Colab notebook you shared something you’re willing to make public? That was super useful.
I’m not dissatisfied with anything! I just think that with the goal of using Stan model coding style, it’ll be easier to just do that straight in JAX rather than going through PyMC. The Colab notebook clears up exactly how you can code models like Stan, which is super helpful.
The Wikipedia page just says this, and I still can’t tell what Theano does.
Theano is a Python library and optimizing compiler for manipulating and evaluating mathematical expressions, especially matrix-valued ones.[2] In Theano, computations are expressed using a NumPy-esque syntax and compiled to run efficiently on either CPU or GPU architectures.
It doesn’t even mention derivatives! In the past, I’ve seen Theano described as a code transformation tool, but I think it may be more autodiff-based? The combination of not being able to even understand the top-level descriptions and it being in Python was a dealbreaker for us.
I had no idea you could do that with Discourse. Thanks!
P.S. The Theano Wikipedia page is a mess. It mentions that Theano hit end of life and was forked by @fonnesbeck as Aesara (no Wikipedia page) and then by the PyMC devs as PyTensor (also no Wikipedia page). I assume that means nobody is still maintaining Theano, even though the Theano page lists the PyMC devs. We struggled with getting a Wikipedia page for Stan, so I’m not surprised there’s not an Aesara or PyMC page, despite both projects being actively maintained.
You can think of pytensor/theano==jax, a library that gives you autodiff and compiled code for array based computations.
Theano had an emphasis on the graph of the computation, which you could manipulate easily from Python, whereas other frameworks emphasize it less / hide it.
Theano influenced a bunch of these frameworks, after which it was deemed “unnecessary” from a research standpoint.
When the original devs abandoned Theano, the pymc-devs forked it as Aesara; then, due to project-related conflicts, it was forked again as PyTensor, still by the pymc-devs.
I quite like this description from our docs (which are admittedly impossible to navigate):
PyTensor combines aspects of a computer algebra system (CAS) with aspects of an optimizing compiler. It can also generate customized code for multiple compiled languages and/or their Python-based interfaces, such as C, Numba, and JAX. This combination of CAS features with optimizing compilation and transpilation is particularly useful for tasks in which complicated mathematical expressions are evaluated repeatedly and evaluation speed is critical. For situations where many different expressions are each evaluated once, PyTensor can minimize the amount of compilation and analysis overhead, but still provide symbolic features such as automatic differentiation.
I personally think of Pytensor as a mix of sympy and <insert your favorite DL library here>
Largely performance and design flexibility. See the evaluations that start on about page 67 of our autodiff paper.
The design that we extended is Sacado (from the COIN-OR project). We cite the David Gay paper that explains how Sacado’s autodiff works. We just had to engineer a more efficient version. The evaluations compare Stan’s autodiff to CppAD, Sacado, Adept, and adolc.
I go over a basic C++ autodiff design from scratch based on functional C++ with continuations and then refactor the lambda-based closure design to the custom version used in Stan:
Is there doc explaining how this all works in more detail? This is the same kind of language about transpilation, computer algebra, and autodiff that confused us about what Theano actually did. We and a number of other people thought it was a code transformation tool like some of the existing Fortran libraries at the time.