Survival analysis example with Nutpie – odd results

I have been working through the Survival Analysis worked example notebook (the version updated for PyMC 5). It works as expected as downloaded, but if I change the sampling to use Nutpie, the calculation of beta is (to me) worryingly different:

PyMC sampler:

Nutpie sampler:

The only difference in the code is that nuts_sampler="nutpie" was added to the idata = pm.sample(…) block.
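For reference, the change amounts to roughly this (the other sampling arguments are whatever the notebook already passes; I haven’t reproduced them here):

with model:
    # only change from the notebook: nuts_sampler="nutpie"
    idata = pm.sample(nuts_sampler="nutpie")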

The original code reports about the same number of divergences as the published version, whereas the Nutpie run generates no divergences.

I’d expected some minor differences in the values reported, but this looks like more than that. Am I overthinking it?

System summary:

macOS 26.1, M2 processor

Python 3.13 (conda-forge)

pymc       : 5.26.1
pytensor   : 2.35.1
statsmodels: 0.14.5
nutpie     : 0.16.3
matplotlib : 3.10.8
numpy      : 2.3.5
pandas     : 2.3.3
arviz      : 0.22.0

For what it’s worth, I’ve noticed this with nutpie too; it will just completely shit the bed on certain models. I haven’t noticed any specific pattern to when it happens.


I think there is an error in the model in the notebook. The gamma distribution should probably have mu=0.01 and sigma=0.01, not alpha=0.01 and beta=0.01. Both samplers diverge because of this.

with pm.Model(coords=coords) as model:
    lambda0 = pm.Gamma("lambda0", mu=0.01, sigma=0.01, dims="intervals")
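For context (my own arithmetic, not from the notebook): with PyMC’s Gamma, mu = alpha/beta and sigma = sqrt(alpha)/beta, so the two parameterizations imply very different priors:

import numpy as np

# Gamma(alpha=0.01, beta=0.01): mean = alpha/beta = 1, sd = sqrt(alpha)/beta = 10
alpha, beta = 0.01, 0.01
print(alpha / beta, np.sqrt(alpha) / beta)    # 1.0 10.0

# Gamma(mu=0.01, sigma=0.01): implied alpha = (mu/sigma)**2 = 1, beta = mu/sigma**2 = 100
mu, sigma = 0.01, 0.01
print((mu / sigma) ** 2, mu / sigma ** 2)     # 1.0 100.0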

I wonder if some version changed which parameters are the defaults when no keywords are specified to pm.Gamma? That would be an unfortunate backwards-incompatible change. (@ricardoV94, if you have any ideas?)

@jessegrabowski If you know of models where the default sampler converges and nutpie doesn’t, it would be great if you could report them!

If we changed the defaults it was accidental, but I don’t think we did.

The target_accept/divergences and the warning:

home/cfonnesbeck/GitHub/pymc/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [normal_rv{0, (0, 0), floatX, False}.0, normal_rv{0, (0, 0), floatX, False}.out]

All of these suggest the model wasn’t sampling well in the first place, so nutpie failing differently isn’t a shocker.

Do you have examples you could share? Which preconditioner was causing problems (diagonal, low-rank plus diagonal, normalizing flow)?

I’m very concerned about this, as I’m busy coding a modified version of Nutpie (continuous rather than blocked, with dynamic adaptive step size, and maybe also isokinetic rather than standard momenta), which is likely to be our recommended default sampler for Stan. As such, it’d be really useful to have challenging test cases before we launch.

For what it is worth, I have been experimenting with using nutpie for the type of phylogenetic models written in Stan that my colleague posted about on the Stan forums a while ago. I have found that for some datasets it does great (faster than the default Stan sampler, good draws), but for others it has terrible ESS compared to Stan and/or divergences, especially with low_rank_modified_mass_matrix=True. I don’t yet have a good sense of which differences in the data make the difference.


More data usually means a more concentrated posterior, which means higher curvature (larger eigenvalues of the negative Hessian of the log density). If this curvature varies around the posterior, a single mass matrix isn’t going to be ideal. What we really care about is the condition number at every point in the posterior after preconditioning with the mass matrix (roughly the ratio of largest to smallest eigenvalue, but that only works when everything is positive definite; hierarchical priors can induce funnel-like behavior that is not positive definite, as can multiplicative non-identifiability). Sometimes you can reparameterize, but if not, pulling back to just a diagonal mass matrix might be better. Stan doesn’t have the low-rank plus diagonal option from Nutpie. The diagonal preconditioner in Nutpie should behave a bit better than Stan’s; if you find otherwise, I’d be very curious to see an example.
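As a toy illustration of the point about varying curvature (my own sketch, not anything from Nutpie or Stan internals): the conditioning after preconditioning is the eigenvalue spread of M^{-1/2} H M^{-1/2}, and a mass matrix M that is perfect at one point can be terrible at another.

import numpy as np

def cond_after_precond(H, M):
    # eigenvalue spread of M^{-1/2} H M^{-1/2} for a diagonal mass matrix M
    m_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(M)))
    eig = np.linalg.eigvalsh(m_inv_sqrt @ H @ m_inv_sqrt)
    return eig.max() / eig.min()

M = np.diag([100.0, 1.0, 0.01])                            # mass matrix tuned to one point
print(cond_after_precond(np.diag([100.0, 1.0, 0.01]), M))  # 1.0: perfectly conditioned there
print(cond_after_precond(np.diag([0.01, 1.0, 100.0]), M))  # 1e8: awful at another point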

I’d love to see one of those failing models! :wink:

A couple of (obviously generic) ideas about how that might happen:

  • If you set a target acceptance rate of 0.8 in Stan, you usually end up with a step size that leads to an acceptance rate of ~0.9. This is because the last step size adaptation window in Stan was chosen a little bit too small. So a model that works well with 0.9 but not with 0.8 might work with Stan but not nutpie, because nutpie usually gives you the acceptance rate you ask for (see the sketch after this list).
  • I think there are combinations of eigenvalues of the posterior covariance where the nutpie diagonal mass matrix is simply worse than the Stan one. I’ve not managed to figure out theoretically which ones those are precisely, but a specific model might just happen to have a covariance structure like that.
  • I’ve seen one model so far where nutpie was worse and I think neither of those was the problem. I think that was a case where NUTS chose a bad tree size in the nutpie case, but a good one in the Stan case. If I increased the tree size manually (with the mindepth argument), nutpie worked better. I’m not sure why the No-U-Turn criterion would have done that; maybe it was just bad luck? The tree size changes in powers of two, so if you are close to a change-point, the efficiency can change suddenly even for relatively small sampler changes? But that’s really just a guess.
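A rough sketch of those two knobs (keyword names as I recall nutpie’s API and as mentioned above; double-check against your nutpie version):

import nutpie

compiled = nutpie.compile_pymc_model(model)  # there is a corresponding compiler for Stan models
trace = nutpie.sample(
    compiled,
    target_accept=0.9,  # first point: nutpie tends to hit the rate you ask for, so try 0.9
    mindepth=6,         # third point: force larger trees (the value here is arbitrary)
)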

It’s tricky to even define what “acceptance rate” means for NUTS, because it’s not a simple Metropolis algorithm. Stan’s target acceptance rate is the average acceptance rate of points in the NUTS trajectory (maybe only in the last doubling; I’m not entirely sure) if they were to be taken as Metropolis proposals. But NUTS doesn’t perform a Metropolis adjustment. Instead, it chooses among the points in the trajectory according to a categorical distribution over all of the points, including the initial point. Points further from the starting point are overweighted in the biased-progressive form of NUTS implemented in Stan. This is all described, albeit very tersely, in the original Hoffman and Gelman paper in JMLR. We (Nawaf Bou-Rabee, Tore Kleppe, Sifan Liu, and I) unpack it a bit more explicitly when reducing it to an instance of Gibbs self-tuning (GIST) in recent papers.
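To make the distinction concrete, here is a toy sketch (my own reading of the above, not Stan’s code) of the two quantities, in terms of the energies H_i of the trajectory points and H0 of the initial point:

import numpy as np

def average_accept_stat(H0, H):
    # what Stan adapts on: the mean Metropolis acceptance probability the
    # trajectory points *would* have had as proposals from the initial point
    return np.mean(np.minimum(1.0, np.exp(H0 - np.asarray(H))))

def multinomial_pick(H, rng):
    # what (multinomial) NUTS actually does: a categorical draw over the points
    # with weights exp(-H_i); the biased-progressive variant additionally
    # favors the newer, farther subtree at each doubling
    w = np.exp(-(np.asarray(H) - np.min(H)))
    return rng.choice(len(w), p=w / w.sum())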

If we target a 0.8 Metropolis accept rate, the chance a transition will not choose the starting point is much higher. NUTS could have used the probability that it jumps away from the starting point by tuning on 1 minus the probability of remaining at the starting position. That would not have been ideal, though, because error accumulates if the step size isn’t stable and we really want to control acceptance at the ends of the trajectory more than at the beginning.

P.S. I’ve recently been playing with Adam in place of dual averaging for step size adaptation in Nutpie, and the results are promising. Both are stochastic gradient algorithms, so they still bounce around a bit over realistic numbers of warmup iterations, even with Robbins-Monro-style step size reduction. A less theoretically sound but more aggressive step size reduction may be in order, because both of these algorithms converge very quickly given how fast Nutpie locks in the mass matrix.
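The shape of the idea, as a toy sketch (my own, not Nutpie’s actual implementation): treat the gap between the target and observed acceptance statistic as a pseudo-gradient for the log step size and feed it to Adam.

import numpy as np

def adapt_log_step(accept_history, target=0.8, lr=0.05,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    log_step, m, v = 0.0, 0.0, 0.0
    for t, a in enumerate(accept_history, start=1):
        g = target - a                    # acceptance too low -> g > 0 -> shrink the step
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        log_step -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return np.exp(log_step)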

The mass matrix is used as a metric to rescale the dimensions for the U-turn condition. With different mass matrices, the U-turn condition is different. NUTS uses those powers of 2 to make sure it only does linearly many U-turn checks rather than quadratically many. This can cause internal U-turns that don’t align with power-of-2 boundaries to be missed.
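Roughly, the metric-dependent check between the two ends of a (sub)trajectory looks like this (my own simplified sketch of the usual criterion):

import numpy as np

def u_turn(q_minus, q_plus, p_minus, p_plus, inv_mass):
    # ends of the trajectory q_minus/q_plus with momenta p_minus/p_plus;
    # inv_mass rescales the dimensions, so a different mass matrix changes
    # when this check fires
    dq = q_plus - q_minus
    return (dq @ (inv_mass @ p_minus)) < 0 or (dq @ (inv_mass @ p_plus)) < 0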

There is also the problem that no fixed mass matrix can precondition a multiscale posterior (which is pretty much everything other than uniform and normal, I think). I have a feeling multiscale posteriors are where a lot of the problems are going to arise, and I don’t think anyone knows what the best target even is here. Something like average condition number in log-concave cases might be reasonable; that’s kind of what Nutpie is trying to target with Fisher divergence anyway. But I suppose we could design a divergence to target it directly.

love the choice of language ahah