Numpyro JAX sampling very slow

I guess there are 3 questions here:

  1. Why is the sampling slow (general profiling)
  2. Is the slowness due to the number of inputs, chains, or draws? And if so,
  3. What is the smallest number of inputs/chains/draws you can get away with?

Going backwards:

3. Sample Size

The minimum sample size for Bayesian inference is 0. Of course, in that case you just get back the priors, but those should already be a reasonable model! At the other extreme, in the probability limit (plim) as the sample grows, your posterior should converge to a delta function on the true population parameters (er, am I allowed to talk about those here?).

None of that actually answers your question, but it's a hard question to answer. The answer is partly determined by how well your parameters are identified, in a strictly mathematical sense: if the gradient of the logp function with respect to a parameter is very small, you will need a huge amount of data to learn anything about it (or it could be impossible).

As a concrete example, I offer the “Constant Elasticity of Substitution (CES)” production function from economics, defined:

Y_t = A_t \left (\alpha ^ {\frac{1}{\psi}} K_t ^ {\frac{\psi - 1}{\psi}} + (1 - \alpha)^{\frac{1}{\psi}} L_t ^ {\frac{\psi - 1}{\psi}} \right ) ^ {\frac{\psi}{\psi - 1}}

Where Y_t is a firm’s output at time t, K_t and L_t are the inputs of capital and labor, and A_t \sim \exp(N(0, \sigma)) is “total factor productivity”, a measure of the technological/technical know-how of the firm. It has two free parameters: \alpha, the share of capital in the production process, and \psi, the elasticity of substitution between capital and labor (how difficult it would be to switch the process to use less capital and more labor, at the margin).
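To make the model concrete, here is a minimal NumPy sketch of the CES function and a simulated dataset. The input distributions, sample size and parameter values are hypothetical choices for illustration, and \sigma is treated as the standard deviation of \log A_t.

```python
import numpy as np


def ces_output(K, L, alpha, psi, A=1.0):
    """CES output Y = A * (alpha^(1/psi) K^r + (1 - alpha)^(1/psi) L^r)^(1/r), with r = (psi - 1)/psi."""
    r = (psi - 1.0) / psi  # undefined at psi == 1 (the Cobb-Douglas limit)
    inner = alpha ** (1.0 / psi) * K**r + (1.0 - alpha) ** (1.0 / psi) * L**r
    return A * inner ** (1.0 / r)


rng = np.random.default_rng(0)
n = 500
alpha_true, psi_true, sigma = 0.33, 1.5, 0.1  # hypothetical "true" parameters

K = np.exp(rng.normal(size=n))  # capital (hypothetical lognormal inputs)
L = np.exp(rng.normal(size=n))  # labor
A = np.exp(rng.normal(0.0, sigma, size=n))  # TFP: log A_t ~ N(0, sigma)
Y = ces_output(K, L, alpha_true, psi_true, A)
```

A quick sanity check: CES has constant returns to scale, so doubling both K and L (holding A_t fixed) should double Y.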

It is “well known” (citation needed, lmao) that the \psi parameter is not identified. To see this, I made a plot of the logp of this model with respect to both parameters:

[Figure: logp of the CES model plotted against \alpha and \psi]

As you can see, there is an enormous amount of curvature in the \alpha direction, and a (somewhat) visible peak somewhere in the 0.2 to 0.5 range (the true value is 0.33 in this case). In the \psi direction, on the other hand, the logp just looks like it's monotonically decreasing. This is hardly proof, but the point is that it should be easy for gradient-based algorithms to quickly and efficiently find regions of high probability for \alpha, while \psi will be much harder.
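The curvature point can be checked numerically. The sketch below is not the code behind the plot above: it simulates data from the CES model with hypothetical values (\alpha = 0.33, \psi = 1.5), writes the log-likelihood of \log Y_t (Gaussian, since \log A_t \sim N(0, \sigma)), and takes finite-difference second derivatives along each axis at the true values. It also fills a grid you could hand to a contour plot.

```python
import numpy as np


def log_ces(alpha, psi, K, L):
    """log of the CES aggregate (without A_t), computed stably in log space."""
    r = (psi - 1.0) / psi  # psi == 1 (Cobb-Douglas limit) must be avoided
    a = np.log(alpha) / psi + r * np.log(K)
    b = np.log1p(-alpha) / psi + r * np.log(L)
    return np.logaddexp(a, b) / r


def log_lik(alpha, psi, log_Y, K, L, sigma):
    """Log-likelihood of log Y_t = log CES + log A_t with log A_t ~ N(0, sigma)."""
    resid = log_Y - log_ces(alpha, psi, K, L)
    return -0.5 * np.sum((resid / sigma) ** 2) - log_Y.size * np.log(sigma * np.sqrt(2.0 * np.pi))


def second_derivative(f, x, h=1e-3):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / h**2


rng = np.random.default_rng(0)
n, alpha_true, psi_true, sigma = 500, 0.33, 1.5, 0.1  # hypothetical values
K = np.exp(rng.normal(size=n))
L = np.exp(rng.normal(size=n))
log_Y = log_ces(alpha_true, psi_true, K, L) + rng.normal(0.0, sigma, size=n)

# Curvature of the log-likelihood along each parameter axis, at the true values.
curv_alpha = second_derivative(lambda a: log_lik(a, psi_true, log_Y, K, L, sigma), alpha_true)
curv_psi = second_derivative(lambda p: log_lik(alpha_true, p, log_Y, K, L, sigma), psi_true)
print(f"d2 logp / d alpha2 = {curv_alpha:.1f}, d2 logp / d psi2 = {curv_psi:.1f}")

# Log-likelihood surface on a grid (psi = 1 is deliberately not on this grid).
alphas = np.linspace(0.05, 0.95, 50)
psis = np.linspace(0.3, 3.0, 40)
surface = np.array([[log_lik(a, p, log_Y, K, L, sigma) for p in psis] for a in alphas])
```

With inputs that vary this much, the \alpha curvature should come out far larger in magnitude than the \psi curvature, though the exact numbers depend on the simulated data. How much there is to learn about \psi depends largely on how much the capital/labor ratio varies in the sample.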

How does this apply to you? If you have strongly identified parameters (in this local, mathematical sense), you need “less” data. That's admittedly not much of an answer. It also depends on how much “surprise” is in your data. If everything is more or less the same, you need less of it. If there are variations to learn, you need more. One way to get a sense of this is to look at the Pareto k-hat statistics for a fitted model. These give you something kinda sorta like what leverage measures in linear models: which observations are problematic/influential/unlikely, given your model. I also haven't given you any tools to analyze identification yet. Pair plots are a good one: look for pairwise dependency structure in the parameters. Anything other than a shapeless random cloud points to weak identification.

Here’s some concrete advice, especially when you are working with complex models and custom likelihoods: always begin with a simulation study. Take your model and use it to draw a set of “true parameters”. Then you can start to study: 1) whether you can recover those parameters from a sample simulated with them, and 2) how increasing the amount of data you generate from those parameters affects the accuracy and precision of your estimates. I have often found that when I write a very complex model, I can't even recover the parameters when the model itself is the data generating process. What hope do I have on real data, then? In these cases, it often (always?) comes down to identification.
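Here is the skeleton of such a simulation study, as a minimal sketch. To keep it self-contained and sampler-free, it swaps in a conjugate model (normal data with known \sigma and a normal prior on the mean) whose posterior is available in closed form; posterior_normal_mean is a made-up helper name. In practice you would replace that helper with a pm.sample call on your own model, and repeat the whole loop over many draws of the true parameters.

```python
import numpy as np


def posterior_normal_mean(y, sigma, prior_mean=0.0, prior_sd=10.0):
    """Closed-form posterior of mu for y_i ~ N(mu, sigma), sigma known, mu ~ N(prior_mean, prior_sd)."""
    precision = 1.0 / prior_sd**2 + y.size / sigma**2
    mean = (prior_mean / prior_sd**2 + y.sum() / sigma**2) / precision
    return mean, np.sqrt(1.0 / precision)


rng = np.random.default_rng(0)
sigma = 2.0

# Step 1: draw a "true" parameter from the prior.
mu_true = rng.normal(0.0, 10.0)

# Step 2: simulate data of increasing size from the model itself and refit each time.
results = {}
for n in (10, 100, 1_000, 10_000):
    y = rng.normal(mu_true, sigma, size=n)
    post_mean, post_sd = posterior_normal_mean(y, sigma)
    results[n] = (post_mean, post_sd)
    print(f"n={n:>6}: posterior mean {post_mean:+.3f} (true {mu_true:+.3f}), sd {post_sd:.4f}")
```

If recovery works, the posterior mean should sit within a few posterior standard deviations of the truth, and the standard deviation should shrink roughly like 1/\sqrt{n}.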

2. Effect of Chains/Draws/Sample Size on Speed

The easy one first: adding more chains does not slow down sampling, provided you have enough compute resources to run all the chains at once. Whether you have the necessary resources can be non-trivial to work out, because each parallel worker can spawn its own workers and so on. In general, though, you only ever need 4-6 chains, and thus 4-6 cores. Assigning a gazillion cores in pm.sample does nothing once each chain has its own worker.

Next, draws obviously add to the compute time. How many draws you need is going to be influenced by the effective sample size you obtain. The rule of thumb that Andrew Gelman gives is that you want at least 1000 effective samples on everything (I think, but I can’t find a citation). If you’re not getting that, but your chains are mixing without divergences, you need to run the chains longer.
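To put numbers on “effective samples”, here is a simplified single-chain ESS estimator (autocorrelations via FFT, truncated with Geyer's initial positive sequence), applied to independent draws and to an AR(1) series standing in for a sticky chain. It is a sketch for intuition, not ArviZ's az.ess, which uses a rank-normalized, split-chain, multi-chain estimator and is what you would actually run on a trace. The ar1_chain helper and the 0.9 autocorrelation are hypothetical.

```python
import numpy as np


def autocorrelation(x):
    """Sample autocorrelation of a 1-D chain at every lag, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = x.size
    centered = x - x.mean()
    # Zero-pad to length 2n so the FFT gives linear (not circular) correlation.
    spectrum = np.fft.rfft(centered, n=2 * n)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * n)[:n] / n
    return acov / acov[0]


def effective_sample_size(x):
    """Single-chain ESS = n / tau, with tau from Geyer's initial positive sequence."""
    rho = autocorrelation(x)
    n = rho.size
    # tau = -1 + 2 * sum of paired autocorrelations (rho[2k] + rho[2k+1]),
    # stopping at the first pair that goes negative.
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau


def ar1_chain(phi, n, rng):
    """Stationary AR(1) series, a stand-in for an autocorrelated MCMC chain."""
    x = np.empty(n)
    x[0] = rng.normal() / np.sqrt(1.0 - phi**2)
    for t in range(1, n):
        x[t] = phi * x[t - 1] + rng.normal()
    return x


rng = np.random.default_rng(1)
n = 5_000
ess_iid = effective_sample_size(rng.normal(size=n))
ess_sticky = effective_sample_size(ar1_chain(0.9, n, rng))
print(f"iid: ESS ~ {ess_iid:.0f} of {n}; AR(1) phi=0.9: ESS ~ {ess_sticky:.0f} of {n}")

# Draws needed for a target ESS, assuming the ESS-per-draw ratio stays constant.
target = 1_000
draws_needed = int(np.ceil(target * n / ess_sticky))
print(f"draws needed for ESS {target}: ~{draws_needed}")
```

For an AR(1) chain with autocorrelation \phi, the ESS is roughly n(1 - \phi)/(1 + \phi), so \phi = 0.9 keeps only about 5% of the draws' worth of information. That ratio is what tells you how much longer to run.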

So we come to the effect of the size of the sample itself. This is going to vary with the nature of the operations in your model (and, importantly, in the gradients). Dense matrix inversion and multiplication scale as \mathcal O(n^3) on n \times n matrices with the standard algorithms, so they scale poorly as you add lots of data. (There might be gains to be had from switching to GPU if linear algebra operations are your bottleneck?) More elementary, elementwise operations are close to linear, so those scale better. Remember that every single draw requires many, many gradient evaluations, because an entire Hamiltonian trajectory has to be simulated to obtain each one. This is why approximate methods like ADVI, potentially with minibatching, are often preferred in practice for huge datasets.
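A back-of-the-envelope version of that last point, as a sketch: in NUTS each leapfrog step costs one gradient evaluation, and at the usual default maximum tree depth of 10 a single iteration can take up to 2^{10} - 1 = 1023 leapfrog steps. The 31 steps per draw and 1 ms per gradient below are hypothetical numbers. PyMC's own NUTS records the actual steps per draw as n_steps in sample_stats, and the per-gradient time can come from timing the compiled dlogp function.

```python
def nuts_cost(draws, tune, chains, mean_leapfrog_steps, seconds_per_grad, parallel_chains=True):
    """Rough NUTS cost: one gradient evaluation per leapfrog step, tuning included.

    Returns (total gradient evaluations, approximate wall-clock seconds).
    Ignores compilation time and per-iteration overhead.
    """
    grads_per_chain = (draws + tune) * mean_leapfrog_steps
    total_grads = grads_per_chain * chains
    # With enough cores, chains run side by side, so wall time tracks a single chain.
    wall_seconds = seconds_per_grad * (grads_per_chain if parallel_chains else total_grads)
    return total_grads, wall_seconds


worst_case_steps = 2**10 - 1  # deepest trajectory at max tree depth 10

typical = nuts_cost(1_000, 1_000, 4, mean_leapfrog_steps=31, seconds_per_grad=1e-3)
worst = nuts_cost(1_000, 1_000, 4, mean_leapfrog_steps=worst_case_steps, seconds_per_grad=1e-3)
print(f"typical: {typical[0]:,} gradients, ~{typical[1]:.0f} s")
print(f"worst case: {worst[0]:,} gradients, ~{worst[1]:.0f} s")
```

With those numbers the typical case is about a minute of wall time with parallel chains, while hitting the tree-depth limit on every draw would take roughly 34 minutes, which is why a slow gradient or a badly scaled posterior hurts so much.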

1. Why is my model slow (profiling)

This is a big topic, but pytensor has some tools to help you find the bottlenecks in your model. The basics are here, and I have an example here that was written before PyTensor forked from Aesara, but you can just find-and-replace aesara with pytensor and it should work. Basically, you can:

  1. Enable profiling by setting pytensor.config.profile = True at the top of your code.
  2. Write your PyMC model, without sampling.
  3. Compile the logp and dlogp functions of your model using f_logp = model.compile_logp() and f_dlogp = model.compile_dlogp(). These are compiled functions that evaluate the log-probability and its gradient at a given point.
  4. Run %timeit on both functions, using arbitrary inputs (e.g. model.initial_point()). This is important, because the profiler needs each function to be run many times in order to collect data on the performance of each Op that makes up your logp (and dlogp) graphs.
  5. Get the profile results with f_logp.f.profile.summary() (and likewise f_dlogp.f.profile.summary()).

The results will be a big readout of Ops and their associated times. It will not show you the profile for inner functions such as scans: you can see in my notebook that it just says aesara.scan.op.Scan, but it doesn't show you what is going on inside the scan. If you have scans, there are some extra steps you have to go through (but IIRC you don't). Using this method, you can see which operations are taking the most time and target them for improvement. In my linked notebook, after I looked inside the scans, I saw that calling pt.eye inside the scan was doing an expensive re-allocation at every step of the loop, so I rewrote the function to use a single global identity matrix, which sped up the computation.

JAX itself also has profiling tools (see here), but I've 1) never used them, and 2) don't know how useful they are in the context of generated code (@ricardoV94 might know more?)

Miscellaneous Thoughts

I made this header plural but I only have one:

  1. Try Nutpie. In principle, it should be faster than JAX on CPU (I think someone told me once that JAX is more optimized to target GPUs, but there isn’t much speedup to be had by running PyMC models on GPU. Though if your data is large enough, maybe there is?)