Sampling speeding up

I often notice that posterior sampling is slow at first, then speeds up significantly after the first few hundred or so warmup samples in each chain.

Why does that happen?

I have no problems but thought I would ask. Seems like a good way to learn more about the sampling process.

Hey, I was interested in this same thing a couple of months ago. During warmup, the sampler is adjusting the number of steps and the step size. Taking a lot of short steps per sample makes the progress bar move slowly, so the tuning process tries to find a way to take a small number of long steps instead. At some point it finds a good balance and there is a massive jump in efficiency. You can explore the tuning process in detail by setting discard_tuned_samples=False.

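import matplotlib.pyplot as plt
import pymc as pm
from scipy import stats

# 100 draws from a standard normal as synthetic data
x = stats.norm(0, 1).rvs(100)

with pm.Model():
    mu = pm.Normal('mu', 0, 1)
    y = pm.Normal('y', mu, 1, observed=x)
    # keep the warmup draws so the tuning statistics can be inspected
    trace = pm.sample(chains=1, discard_tuned_samples=False)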

Then if you look at the step size with plt.plot(trace.warmup_sample_stats.step_size[0])

[Plot: step size over the warmup samples]

And the number of steps with plt.plot(trace.warmup_sample_stats.n_steps[0]) you can see the jump.

[Plot: number of steps over the warmup samples]

Ah ok, I see. So it’s really as simple as more steps per sample in the warmup.

Hey! I’ve done a lot of work on the tuning side for these algorithms, and you’ve got it mostly right. NUTS keeps doubling the trajectory length until it hits a dynamic stopping criterion (a U-turn), and then does some bookkeeping to maintain detailed balance. Part of the U-turn criterion is “is this the 10th doubling?” (I think the 8th doubling during tuning). So @daniel-saunders-phil spotted this well: note that the number of steps is always 2^n - 1 (I see a 31, two 15’s, a bunch of 7’s, and so on), where n is the number of doublings.
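
To make the 2^n - 1 pattern concrete, here is a minimal sketch in plain Python. It is not PyMC code: the helper names are made up for illustration, and it assumes every doubling runs to completion (a trajectory that stops partway through a doubling would not fit this pattern).

```python
def leapfrog_steps_after(doublings: int) -> int:
    """Leapfrog steps in a trajectory built from `doublings` complete doublings.

    Hypothetical helper: the trajectory starts as a single state, and each
    doubling doubles the number of states. Connecting 2**n states takes
    2**n - 1 leapfrog steps.
    """
    n_states = 2 ** doublings
    return n_states - 1


def is_complete_tree(n_steps: int) -> bool:
    """True if n_steps has the form 2**n - 1, i.e. n_steps + 1 is a power of two."""
    n_states = n_steps + 1
    return n_states > 0 and (n_states & (n_states - 1)) == 0


# Step counts for 1 through 10 doublings (10 being the cap mentioned above).
step_counts = [leapfrog_steps_after(d) for d in range(1, 11)]
print(step_counts)  # [1, 3, 7, 15, 31, 63, 127, 255, 511, 1023]
```

The 31, 15 and 7 read off the plot correspond to 5, 4 and 3 doublings.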

PyMC constantly updates the step size during tuning to target an acceptance rate of 80% (using stochastic optimization!). It also updates the mass matrix, which is diagonal by default, every 101st step. This is why there appears to be a discontinuity at step 101 in @daniel-saunders-phil’s first plot: the mass matrix adaptation kicks in, and after that it is updated live with a windowed approach.
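
For a feel of how a step size can be driven toward an 80% acceptance rate, here is a rough sketch of dual averaging, the stochastic optimization scheme from the NUTS paper. Treat it as an illustration, not PyMC's implementation: the constants gamma, t0 and kappa are the paper's defaults, whether PyMC's internals match them exactly is an assumption, and toy_accept_rate is a made-up stand-in for the acceptance statistic a real sampler would report.

```python
import math


def toy_accept_rate(step_size: float) -> float:
    """Hypothetical stand-in for a sampler: larger steps are accepted less often."""
    return 1.0 / (1.0 + step_size ** 2)


def dual_average_step_size(initial_step, target_accept=0.8, n_tune=500,
                           gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size so the toy acceptance rate approaches target_accept."""
    mu = math.log(10.0 * initial_step)  # value the iterates are shrunk toward
    error_avg = 0.0                     # running average of (target - observed)
    log_step = math.log(initial_step)   # current (noisy) iterate
    log_step_avg = 0.0                  # smoothed iterate, used after tuning
    for m in range(1, n_tune + 1):
        accept = toy_accept_rate(math.exp(log_step))
        w = 1.0 / (m + t0)
        error_avg = (1.0 - w) * error_avg + w * (target_accept - accept)
        # Accepting too often makes error_avg negative, which grows the step.
        log_step = mu - math.sqrt(m) / gamma * error_avg
        eta = m ** (-kappa)
        log_step_avg = eta * log_step + (1.0 - eta) * log_step_avg
    return math.exp(log_step_avg)


tuned_step = dual_average_step_size(initial_step=1.0)
tuned_accept = toy_accept_rate(tuned_step)
print(f"tuned step size: {tuned_step:.3f}, acceptance rate: {tuned_accept:.3f}")
```

In a real sampler the acceptance statistic is noisy and changes whenever the mass matrix is updated, which is part of why the step size trace in the plot keeps moving during warmup.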

Tuning in PyMC3 is a talk from a few years ago on how PyMC and Stan do tuning. I’m not sure whether everything in it is still true for both! There are a bunch of gifs on the page that don’t autoplay anymore, but do spark joy: I had to right click and choose “Show Controls”.

There are other adaptation strategies. For instance, if you pass init="jitter+adapt_diag_grad", you’ll get an algorithm that updates the mass matrix continuously and shouldn’t have such discrete jumps.

that post is really great @colcarroll
