The purpose of this blog post is to provide a brief, minimally mathematical introduction to recent work on the edge of stability, layerwise Jacobian alignment, and related phenomena. The intended audience is people who work on neural networks - and hence already know terms like gradient descent, multilayer perceptron, etc. - but are not deep into the mathematical weeds. My aim is to clearly explain what the hell I've been doing for the last three years and why it matters. In this first part I will explain the seminal paper by Cohen et al., and in the second part I will discuss how we extended that work in our own paper. There may be further parts after that, depending on how this goes.
Let me start by setting the stage. I work on neural networks. It bothers me deeply that I do not understand why neural networks really work. Why do neural networks successfully generalize from a limited, finite training dataset to new, previously unseen data? Now, for most people, the fact that this does work is much more important than why. But, as a scientist, this frustrates me, and it scares me. The reason it frustrates me is obvious. It scares me because I value understanding for its own sake - but the people who pay for science mostly do not. They fund science because they want to build a new widget that serves some practical purpose. Which, you know, fair enough. Over the last century, the way to get those new widgets has been (in part) to let people like me figure out how to understand phenomena, then to leverage that understanding. But if they come to believe that they can pour compute and data into a problem, and get a better widget faster for less money, they are not going to keep paying for people like me to do the things that I do.
The deep learning revolution began with problems where the traditional scientific approach had not had much luck, like natural language generation. But the empire of the neural network is annexing new provinces every day, and it is already well-entrenched on "traditionally scientific" territory like protein folding. I imagine a future where, instead of physicists building theories out of CERN's output, we use that data to update a neural network, and we don't worry about why particular inputs produce particular outputs. I do not think that future is necessarily certain or even probable, but I think it is closer and more likely than most scientists realize. If neural networks genuinely solve problems more effectively than understanding those problems does, then we are not going to be able to avoid using them. Therefore, if we are going to continue to expand our understanding of the world, we need to figure out how to extract that understanding from the neural networks. That is why I feel a real sense of urgency about solving the problem of why neural networks work.
That, and it's a genuinely fascinating problem.
If we are going to understand why neural networks work, we need to understand their training process. Some theoretical approaches start by simply assuming a neural network that has, in some unspecified fashion, already learned a dataset, and proceed from there. However, we know from experiment that there do exist neural network parameters that memorize a training set but perform badly on new data (see Zhang et al. 2016). We just don't seem to land on those parameters when we train networks in practice. So understanding the training process itself is a key part of understanding neural network generalization. And the big result of Cohen et al. 2021 is that the training process does not work the way that we think it does.
Before I continue, I want to specify that we are going to talk exclusively about non-stochastic, full-batch gradient descent. The only randomness in this process will be the initialization of the parameters before training begins. I'm also going to discuss only basic gradient descent without momentum or other tricks, although it's important to note that Cohen et al. 2021 show that these results still apply when using momentum, and follow-up work shows that they also hold for Adam and other adaptive optimizers and for Sharpness-Aware Minimization.
As a mathematician, I traditionally think of training neural networks as solving a system of ordinary differential equations. The usual metaphor is that we are rolling smoothly downhill. We imagine ourselves on a landscape, where our horizontal position on the landscape corresponds to a particular set of network parameters, and the height is the training loss. We start at some random point on the landscape, and then roll downhill until we reach zero training loss. This metaphor is strained because the "landscape" is not 1- or 2-dimensional but many-millions-dimensional - every parameter in the network corresponds to a possible direction we can move horizontally - but it's still useful.
However, one of the ways that the metaphor breaks down is that we can't "see" the landscape as a whole. We only know our current height (the loss) and the slope at our current position (the gradient). In gradient descent we take a step in the downslope direction, whose size is controlled by a pre-specified number called the learning rate. But the slope of the landscape is itself changing underneath our feet - it's curved, not straight. If we step too far we will overshoot, and end up higher than where we started. If the learning rate is too high, we say training is unstable. If training is unstable, it no longer accurately approximates the system of ordinary differential equations, and we would expect the loss to blow up to infinity. For a given point on the landscape and a given direction, we call the speed at which the slope is changing under us in that direction the curvature in that direction. And, for a given point on the landscape, we call the maximum curvature over all directions the sharpness.
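For readers who like to see the symbols (everyone else can safely skip this), here is a minimal sketch, writing θ for the network parameters, L for the training loss, and η for the learning rate. The "rolling downhill" picture is the gradient flow differential equation, and gradient descent is its step-by-step approximation:

$$\frac{d\theta}{dt} = -\nabla L(\theta) \quad \text{(gradient flow)}, \qquad \theta_{k+1} = \theta_k - \eta\, \nabla L(\theta_k) \quad \text{(gradient descent)}.$$

Near a minimum the loss is roughly quadratic, and on a quadratic the gradient descent update only stays stable when η is less than 2 divided by the sharpness. That is where the stability threshold in the next paragraph comes from.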
If we know the sharpness, we can calculate the maximum learning rate at which gradient descent will be guaranteed to be stable. However, in practice we don't do this because it's extremely computationally expensive. The curvatures in the different directions are determined by the eigenvalues of something called the Hessian matrix of the training loss, and the sharpness is equal to the largest of those eigenvalues. That matrix has as many rows and columns as the network has parameters, so it's too big to even be materialized in memory. It's still possible to calculate Hessian-vector products using automatic differentiation, without ever forming the matrix, and that is enough to get a few of the biggest eigenvalues, including the sharpness, using a technique called power iteration. However, it's still extremely expensive. So in practice learning rates are determined empirically, by finding the biggest learning rate that still works.
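To make that computation concrete, here is a minimal sketch of estimating the sharpness by power iteration on Hessian-vector products, assuming PyTorch; `loss` and `params` are placeholders for the full-batch training loss and the list of network parameters.

```python
import torch

def estimate_sharpness(loss, params, n_iters=50):
    """Estimate the sharpness (largest Hessian eigenvalue) of `loss` with
    respect to `params` without ever materializing the Hessian matrix."""
    # First-order gradients, kept in the graph so we can differentiate again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])

    v = torch.randn_like(flat_grad)
    v = v / v.norm()
    eigenvalue = None
    for _ in range(n_iters):
        # Hessian-vector product: differentiating (gradient . v) with respect
        # to the parameters gives H @ v (Pearlmutter's trick).
        hv = torch.autograd.grad(flat_grad @ v, params, retain_graph=True)
        hv = torch.cat([h.reshape(-1) for h in hv])
        eigenvalue = torch.dot(v, hv)  # Rayleigh quotient -> top eigenvalue
        v = hv / hv.norm()
    return eigenvalue.item()
```

Each Hessian-vector product costs roughly as much as an extra forward-and-backward pass, and you need many of them per estimate, which is why measuring the sharpness at every training step is so expensive.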
What Jeremy Cohen and his coauthors did was train neural networks using full-batch gradient descent, measuring the sharpness as they went. They showed that, when we train a neural network, the sharpness steadily rises until the network is just barely overshooting, and then hovers right at that threshold for the rest of training - hence, on the "edge of stability". If we reduce the learning rate, the sharpness rises to compensate! They showed experimentally that this occurs on an incredible variety of network architectures, datasets, and loss functions. As a result, the loss bounces up and down at each iteration, but it still decreases on average over time. This had not been previously noticed because, in practice, nobody uses full-batch gradient descent! Everyone uses stochastic gradient descent (SGD), and if you're training with SGD and your loss bounces up and down, you chalk it up to getting an unlucky mini-batch.
The figure below comes from some networks I trained myself. I used six-layer multi-layer perceptrons with a width of 512 and the exponential linear unit activation, and trained them on a subset of CIFAR-10 using cross-entropy loss. The vertical axis is the sharpness, while the horizontal axis is the "time" in the dynamical system, equal to the learning rate multiplied by the number of iterations. The blue, cyan, and green curves show networks trained using gradient descent with various fixed learning rates, while the purple curve shows a network trained using the exponential Euler method (which I will discuss in the second part, but which accurately approximates the gradient flow). Each solid curve shows the sharpness of the parameters of a particular network over the course of training, while the dashed lines show the maximum sharpness at which training at that learning rate is stable (for plain gradient descent, 2 divided by the learning rate). As I reduce the learning rate, the network reaches higher and higher sharpness, staying just above the point of stability for most of training until the training loss falls close to zero. Notice how, once the sharpness drops below the stability line, the curves suddenly become smooth? Below that line, training is stable. Above that line, training is unstable: the loss is decreasing, but not consistently, and the network is not following the gradient flow.
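For the curious, the setup behind this figure is roughly the sketch below. This is a simplified stand-in rather than my actual training script: the random tensors are placeholders for the CIFAR-10 subset, and the learning rate is just an illustrative value.

```python
import torch
import torch.nn as nn

def make_mlp(in_dim=3 * 32 * 32, width=512, depth=6, n_classes=10):
    # Six-layer multi-layer perceptron with ELU activations, as described above.
    layers, d = [], in_dim
    for _ in range(depth - 1):
        layers += [nn.Linear(d, width), nn.ELU()]
        d = width
    layers.append(nn.Linear(d, n_classes))
    return nn.Sequential(*layers)

model = make_mlp()
# Placeholder data standing in for a fixed, flattened subset of CIFAR-10.
X = torch.randn(5000, 3 * 32 * 32)
y = torch.randint(0, 10, (5000,))

lr = 0.01  # illustrative fixed learning rate
loss_fn = nn.CrossEntropyLoss()
for step in range(1000):
    loss = loss_fn(model(X), y)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p -= lr * p.grad  # one full-batch gradient descent step, no momentum
```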
So, if training is on the edge of stability, why does the loss decrease on average? Remember that the sharpness of a point is only the curvature in the worst direction at that point. What appears to be happening is that the curvature is only rising in a small number of directions. We've known for some time that, during training, there are a small number of "top" directions in which the curvature is high, and a much larger number of "bulk" directions in which the curvature is more moderate (see, e.g., Sagun et al. 2018). The process looks something like the picture below, where the horizontal axis represents the top directions, while the depth axis represents the bulk directions. We're bouncing back and forth in the top directions, but we're still making progress in the bulk directions. This picture doesn't really show the full story - in particular, it doesn't show how the curvature in the horizontal direction increases as we get deeper into the valley - but we'd need at least one more dimension to show the real shape.
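If it helps, here is a toy numerical version of that picture: a hand-built quadratic valley (nothing to do with a real network) where the x-direction plays the role of a single "top" direction with curvature 100 and the y-direction plays the role of a "bulk" direction with curvature 1. The learning rate is chosen just below the stability threshold of the sharp direction.

```python
import numpy as np

# Toy valley: loss(x, y) = 50 * x**2 + 0.5 * y**2, so the curvature is
# 100 along x (the "top" direction) and 1 along y (the "bulk" direction).
def gradient(w):
    x, y = w
    return np.array([100.0 * x, 1.0 * y])

lr = 0.0199  # just under 2 / 100, the stability threshold of the sharp direction
w = np.array([0.5, 10.0])
for step in range(300):
    w = w - lr * gradient(w)
    # x flips sign every step and shrinks only very slowly (bouncing across
    # the valley), while y decays smoothly toward the bottom (steady progress
    # in the bulk direction).
```

Gradient descent ping-pongs across the narrow axis of the valley while drifting steadily along its floor, which is the behaviour in the picture above (minus the progressive sharpening).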
This implies that much of what we think we know about the training process is wrong, and many of our empirical results may not mean what we think they mean. For example, there's been a line of thinking that neural networks don't really use most of their parameters, because experiment has shown that the vast majority of the gradient lies in a narrow subspace (see Gur-Ari et al. 2018). That narrow subspace is spanned by the top directions. It's true that most of the gradient lies there, but it's just jumping back and forth, cancelling itself out, while the actual progress is happening in the bulk directions!
As a mathematician, I find this deeply counterintuitive. When I first read Cohen et al. 2021, it hit me like a thunderbolt. And it raised an immediate question: what is down at the bottom of that valley? Cohen et al. did run a few experiments where they adjusted the learning rate as they went to keep training stable, so as to follow the exact gradient flow. In our follow-up work, we borrowed more powerful solvers from the field of stiff ODEs to explore this new terrain thoroughly. In the process, we discovered two things: first, we traced the rise in the sharpness to a phenomenon we named layerwise Jacobian alignment. And second, we showed that this layerwise Jacobian alignment scales with the size of the dataset. I'll explain in detail in the next post.