In the first part of this series, I explained that "training on the edge of stability" means that the sharpness of the loss landscape rises over the course of training until gradient descent no longer approximates the gradient flow. In this part, I will discuss where the increase in the sharpness comes from. If you haven't read the first part yet, go back and do that now, because I'm going to assume you already know what terms like "sharpness" and "stable" mean.
In our paper, we implemented a more powerful solver so we could train neural networks by following the gradient flow exactly, without entering the edge of stability. Cohen et al. (2021) already did something similar, adjusting the learning rate as they trained to keep training stable. This works, but the compute required scales roughly cubically with training dataset size, and we wanted to run a lot of experiments. The specific approach we used is called an exponential Euler solver. The exact mechanics are complicated, but it lets us compensate for the curvature in the highly-curved directions while taking large steps in the non-curved directions. It's still too expensive for training neural networks on big datasets like ImageNet, but we can use it to train on large subsets of CIFAR-10. You can find my current implementation of the solver here.
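To give a flavor of what that looks like (the real details are in the paper and the linked code; this is my own simplification, not that implementation), here's a minimal sketch of a single exponential-Euler-style step. It assumes you've already computed the top-k eigenvalues and eigenvectors of the Hessian somehow, say with a Lanczos routine: the curved directions get the exact exponential decay factor, and everything else gets a plain Euler step.

```python
import torch

def exponential_euler_step(params, grad, eigvals, eigvecs, h):
    """Sketch of one exponential-Euler-style step for gradient flow dθ/dt = -∇L(θ).

    Locally linearizing the gradient as ∇L(θ) ≈ g + H(θ - θ_t), the linear ODE can
    be solved exactly. We apply that exact solution in the top Hessian
    eigendirections (the highly-curved ones) and fall back to a plain Euler step
    of size h everywhere else.

    params:  flattened parameters θ_t
    grad:    flattened gradient g = ∇L(θ_t)
    eigvals: top-k Hessian eigenvalues, shape [k]
    eigvecs: corresponding eigenvectors as columns, shape [n, k]
    h:       step size in gradient-flow time
    """
    coeffs = eigvecs.T @ grad                        # gradient components in the curved subspace
    phi = (1.0 - torch.exp(-h * eigvals)) / eigvals  # exact decay factor (1 - e^{-hλ}) / λ
    curved_step = eigvecs @ (phi * coeffs)           # exact step within the curved subspace
    flat_grad = grad - eigvecs @ coeffs              # remainder of the gradient
    return params - curved_step - h * flat_grad      # plain Euler step for the rest
```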
Second, we mathematically decomposed the curvature into components and measured each component over the course of training. The details get long and intensely mathematical, and you can find them in the paper, so I'm just going to focus on the upshot. To start, let's take a look at what happens to the sharpness if we train the network without entering the edge of stability. In the figure below, I trained a collection of neural networks on subsets of CIFAR-10 of various sizes. The horizontal axis is time, in the dynamical-systems sense, and the vertical axis is the sharpness. The color of each curve corresponds to the size of the subset. Each network is a 6-layer multilayer perceptron with a width of 512 and exponential linear unit (ELU) activation functions.
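For concreteness, here's roughly what that architecture looks like in PyTorch. Treat anything beyond "6 layers, width 512, ELU" - the flattening of the input, no activation after the final layer - as my reading of the setup rather than the exact training code.

```python
import torch.nn as nn

def make_mlp(depth=6, width=512, in_dim=32 * 32 * 3, out_dim=10):
    """A multilayer perceptron of roughly the shape used in these experiments:
    6 Linear layers of width 512 with ELU activations, taking flattened
    32x32x3 CIFAR-10 images and producing 10 class scores."""
    dims = [in_dim] + [width] * (depth - 1) + [out_dim]
    layers = [nn.Flatten()]
    for i in range(depth):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < depth - 1:            # no activation after the final layer
            layers.append(nn.ELU())
    return nn.Sequential(*layers)
```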
What we see is an initial transient of high sharpness, followed by a very rapid decline to a low level, followed by a steady rise over time, until the sharpness finally plateaus and begins to drop at the end of training. And the peak sharpness increases as the size of the dataset increases.
The initial transient is still a mystery - it's actually the focus of my current research - and it does not come from the same place as the increase in sharpness over the course of training. However, it also only lasts 1-10 iterations, so whatever it's doing doesn't affect most of training. We do know the cause of the drop-off at the end: Cohen et al. showed that it's an artifact of the cross-entropy criterion. The exact details are a bit mathematical, and with other criteria like mean-squared error the sharpness just keeps rising until the end of training, so I don't think it's worth getting into. What we're here to talk about today is the steady, consistent rise between the initial transient and the final drop-off. The component of the sharpness that causes this increase is something we called layerwise Jacobian alignment.
The layerwise Jacobians are matrices that measure how a change in the output of one layer affects the output of the next layer. They depend on both the network parameters and the network input - the effect of a change at layer 1 on layer 2 depends on what the input to layer 1 was. They're also what's used to backpropagate the gradient from the network outputs to the parameters. For each layer, we get the gradient of that layer's parameters from the gradient of its outputs, and the gradient of the outputs is found by multiplying the next layer's gradient by the (transposed) layerwise Jacobian. So the layerwise Jacobians also control how a gradient signal at the network outputs turns into gradients at each of the layers.
So, suppose we have these layerwise Jacobians, and we have a gradient at the output of the network, that is, at layer 6. The magnitude of the gradient at layer 5 depends on both the magnitude and the direction of the gradient at layer 6: for some directions, the layerwise Jacobian will amplify the gradient, while for other directions it will dampen it. The gradient at layer 5 then goes on to determine the gradient at layer 4, which again depends on both the magnitude and direction at layer 5, which in turn depend on the magnitude and direction at layer 6. Then the gradient at layer 4 propagates on to layer 3, then layer 2, and so on.
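To make that bookkeeping concrete, here's a small sketch of how you could compute the layerwise Jacobians with autograd and push an output gradient back through them. The `layer_fns` argument is my own framing: a list of per-layer functions, e.g. each Linear-plus-ELU block of the MLP above, applied to a single flattened input.

```python
import torch

def layerwise_jacobians(layer_fns, x):
    """J_k = ∂(layer k output)/∂(layer k input), evaluated at one input x."""
    jacobians, h = [], x
    for f in layer_fns:
        jacobians.append(torch.autograd.functional.jacobian(f, h))
        h = f(h)
    return jacobians

def backpropagate(jacobians, g_out):
    """Propagate an output-space gradient back layer by layer: g_{k-1} = J_k^T g_k,
    exactly as in standard backpropagation. Returns the gradient at every layer,
    from the first layer's input up to the network output."""
    grads = [g_out]
    for J in reversed(jacobians):
        grads.append(J.T @ grads[-1])
    return list(reversed(grads))
```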
Let's consider my standard 6-layer multilayer perceptron with a width of 512 and the exponential linear unit activation. We initialize the weights with the standard Kaiming initialization, then pick an input from CIFAR-10 at random, giving us a set of layerwise Jacobians. Then we calculate the gradient signal at the network output that has the largest corresponding gradient at layer 1. As we can see, the gradient is slowly amplified, reaching a magnitude of 2.47:
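The 2.47 in the figure is the paper's measurement; here's a sketch of how you could compute that most-amplified output direction yourself. Since the gradient at layer 1 is g_1 = (J_6 J_5 ... J_2)^T g_6, the unit output gradient that is amplified the most is the top left singular vector of the product M = J_6 ... J_2, and the amplification factor is its top singular value. I'm assuming the `jacobians` list here holds J_2 through J_6 from the sketch above.

```python
import torch

def most_amplified_output_gradient(jacobians):
    """Given [J_2, ..., J_L], return the unit gradient at the network output whose
    backpropagated gradient at layer 1 has the largest magnitude, along with that
    magnitude (the top singular value of the chained Jacobian)."""
    M = jacobians[0]
    for J in jacobians[1:]:
        M = J @ M                       # accumulate J_L ... J_2
    U, S, Vh = torch.linalg.svd(M)
    return U[:, 0], S[0].item()         # output direction, amplification factor
```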
As we backpropagate the gradient from the end of the network to the front, parts of it are amplified and parts are dampened, but on average the effects roughly cancel out. Now suppose we train that network on a subset of CIFAR-10 of size 128. As the network parameters change, the layerwise Jacobians change with them, and when we calculate that optimal gradient again, the amplification has gotten much stronger:
Now, when we backpropagate that gradient, it is consistently amplified at each layer, getting bigger and bigger as we go down toward the network inputs. Critically, this isn't because the layerwise Jacobians have changed so that they amplify every gradient: a random gradient will still, on average, maintain roughly constant magnitude through the network. Instead, the layerwise Jacobians are rotating so that they amplify one specific gradient. When we first initialize the network, each layer has some direction it amplifies: there is a direction such that a gradient at layer 6 along it is amplified at layer 5, a direction at layer 5 that is amplified at layer 4, and so on. But the amplified gradient at layer 5 won't line up with the most-amplified direction at layer 4, nor that one with layer 3's, and so on down the network. After training, though, the amplified direction at layer 5 produces a gradient at layer 4 that falls on the amplified direction at layer 4, which produces a gradient at layer 3 that is amplified in turn, and so on.
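One crude way to see this numerically (a proxy I find intuitive, not necessarily the exact alignment measure from the paper) is to compare the largest amplification the whole chain can achieve against the product of each layer's individual largest amplification. If the most-amplified directions don't line up across layers, the ratio sits far below 1; as they align, it climbs toward 1.

```python
import torch

def alignment_proxy(jacobians):
    """Ratio of the chained Jacobians' top singular value to the product of the
    per-layer top singular values. Near 0: the layers' most-amplified directions
    don't line up. Near 1: each layer's amplified output feeds straight into the
    next layer's amplified direction."""
    M = jacobians[0]
    per_layer = torch.linalg.matrix_norm(jacobians[0], ord=2)
    for J in jacobians[1:]:
        M = J @ M
        per_layer = per_layer * torch.linalg.matrix_norm(J, ord=2)
    return (torch.linalg.matrix_norm(M, ord=2) / per_layer).item()
```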
Here's a graph showing how layerwise Jacobian alignment increases over the course of training:
Now let's go back to the process of actual network training and think about how this affects it. Once the layerwise Jacobians start to become aligned, a small change in the weights and biases of the first layer can cause a small change in the preactivations of the first layer, which causes a large change in the preactivations of the last layer, which causes a large change in the network outputs, which overshoots the change we were trying to make! That's what causes the sharpness to rise. If we could magically keep this alignment from occurring, neural network training would be stable with standard learning rates.
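You can watch this mechanism directly with a quick finite-difference check: nudge the first layer's weights by a small random amount and see how much the network outputs move. This is just an illustration of the mechanism, not a measurement from the paper, and a random direction understates the effect - the worst-case direction would show an even larger ratio.

```python
import torch

def first_layer_sensitivity(model, x, eps=1e-3):
    """Roughly measure how much a small random perturbation of the first Linear
    layer's weights changes the network outputs on a batch x. As the layerwise
    Jacobians align, this ratio grows, and the sharpness grows with it."""
    first = next(m for m in model.modules() if isinstance(m, torch.nn.Linear))
    with torch.no_grad():
        baseline = model(x)
        delta = eps * torch.randn_like(first.weight)
        first.weight += delta
        perturbed = model(x)
        first.weight -= delta                     # restore the original weights
    return ((perturbed - baseline).norm() / delta.norm()).item()
```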
So that's layerwise Jacobian alignment. I hope it was at least somewhat comprehensible. Further installments will follow as new results, my own and others', seem worth discussing.