Stable Solvers¶
Stable Solvers provides solvers for training neural networks without entering the edge of stability, for use in investigating neural network loss landscapes. As discovered in Cohen et al. 2021, during neural network training the curvature of the loss landscape, measured by the top eigenvalue of the Hessian matrix, rises until training begins to overshoot. Stable Solvers provides adaptive solvers that avoid this overshooting and follow the true gradient flow, to support scientific research into the loss landscape. You can find Stable Solvers here, or install it using pip:
pip install stable-solvers
The library currently provides two solvers: adaptive gradient descent and the exponential Euler solver. Adaptive gradient descent calculates the top eigenvalue of the Hessian matrix at every iteration and adjusts the learning rate to prevent overshooting. The exponential Euler solver exploits our knowledge of the quadratic terms of the loss function to take larger steps. Each step of the adaptive gradient method is cheaper, but it has to take more of them. Which solver is optimal depends on the problem to be solved: the exponential Euler method is generally best if the number of network outputs is small or the dataset size is very large, while the adaptive gradient method is best if neither of those conditions applies.
Since both solvers require calculating eigenvalues of the Hessian matrix, they need direct access to the network, dataset, and criterion. We package these into a class called LossFunction. The expected syntax is:
import stable_solvers as stable

net = ...        # a torch.nn.Module
dataset = ...    # a torch.utils.data.Dataset
criterion = ...  # a callable criterion

loss_func = stable.LossFunction(
    dataset=dataset,
    criterion=criterion,
    net=net,
    ...
)

params = loss_func.initialize_parameters()
solver = stable.ExponentialEulerSolver(
    params=params,
    loss=loss_func,
    max_step_size=0.01,
    stiff_dim=...,
)

loss = float('inf')
while loss > 0.1:
    loss = solver.step().loss
A full example can be found in this notebook.
- class stable_solvers.LossFunction(dataset: Dataset, criterion: Callable, net: Module, **dataloader_kwargs)¶
This class encapsulates a dataset, criterion, and network architecture, and takes a set of network parameters as input. It bundles these objects, together with their associated dataloader settings, so that they can be passed to other functions as a single unit, and it also provides several utility methods.
The expected signatures of the dataset, criterion, and network, along with the expected way to construct a LossFunction, are illustrated in the sketch following the argument list below.
- __init__(dataset: Dataset, criterion: Callable, net: Module, **dataloader_kwargs)¶
- Args:
dataset (torch.utils.data.Dataset): The dataset.
criterion (callable): The criterion.
net (torch.nn.Module): The network architecture.
In addition, keyword arguments such as num_workers and batch_size may be passed; these will be used when instantiating the torch.utils.data.DataLoader.
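As a concrete illustration, the following sketch wires a toy regression problem into a LossFunction. The particular dataset, network, and criterion are placeholder choices that assume the usual PyTorch conventions, not requirements of the library:

import torch
from torch.utils.data import TensorDataset

import stable_solvers as stable

# Toy dataset: 128 examples with 8 input features and 1 regression target.
inputs = torch.randn(128, 8)
targets = torch.randn(128, 1)
dataset = TensorDataset(inputs, targets)

# Any torch.nn.Module can serve as the network; this one is illustrative.
net = torch.nn.Sequential(
    torch.nn.Linear(8, 32),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 1),
)

# The criterion is assumed to follow the usual (outputs, targets) -> loss convention.
criterion = torch.nn.MSELoss()

# Extra keyword arguments (batch_size, num_workers, ...) are forwarded to the DataLoader.
loss_func = stable.LossFunction(
    dataset=dataset,
    criterion=criterion,
    net=net,
    batch_size=64,
)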
- forward(params: Tensor) → Tensor¶
Calculates the loss for a given choice of parameters.
- gradient(params: Tensor) → Tensor¶
Calculates the gradient of the loss with respect to the parameters, and returns it along with the value of the loss.
- initialize_parameters(gain: float = 1.4142135623730951, device: device = device(type='cuda'), dtype: dtype = torch.float32) → Tensor¶
Convenience function to create a suitable torch.Tensor to use as parameters. Uses the Kaiming initialization for weights, and zeros for biases.
- Args:
gain (float): The gain for the activation function.
device (torch.device): The device to create the tensor on.
dtype (torch.dtype): The dtype of the tensor to create.
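Continuing the sketch above, the utility methods can be called directly on the LossFunction; note that the exact packaging of gradient()'s return value is an assumption here:

# Create a flat parameter tensor (defaults target CUDA; override if needed).
params = loss_func.initialize_parameters(device=torch.device('cpu'))

# Evaluate the training loss at these parameters.
loss = loss_func.forward(params)

# gradient() evaluates the gradient of the loss; per the description above it
# also returns the value of the loss (exact return structure assumed).
grad = loss_func.gradient(params)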
Solvers¶
- class stable_solvers.GradientDescent(params: Tensor, loss: LossFunction, lr: float)¶
Performs conventional gradient descent without momentum, following the update rule:
\[\theta_{u+1} = \theta_u - \eta \nabla_\theta \widetilde{\mathcal{L}}(\theta)\]
Where \(\theta_u\) denotes the parameters at iteration \(u\), \(\eta\) is the learning rate, and \(\widetilde{\mathcal{L}}(\theta)\) is the training loss.
This class is primarily provided for comparison purposes, but training with conventional gradient descent can be stable if the learning rate is small enough.
- __init__(params: Tensor, loss: LossFunction, lr: float)¶
- Args:
params (torch.Tensor): The parameters of the network that are being optimized.
loss (LossFunction): The loss function.
lr (float): Learning rate.
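For baseline comparisons, a fixed-step solver can be constructed the same way as the adaptive solvers, reusing the loss_func from the introductory example. The step() call below assumes GradientDescent exposes the same stepping interface as the other solvers, and the learning rate is an arbitrary small value:

# Conventional gradient descent with a deliberately small, fixed learning rate.
baseline = stable.GradientDescent(
    params=loss_func.initialize_parameters(),
    loss=loss_func,
    lr=1e-3,
)
for _ in range(1000):
    report = baseline.step()  # assumed to return a report, as for the other solvers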
- class stable_solvers.AdaptiveGradientDescent(params: Tensor, loss: LossFunction, lr: float, warmup_iters: int = 0, warmup_factor: float = 1.0)¶
Performs gradient descent without momentum, adapting the learning rate at every step to prevent entering the edge of stability:
\[ \begin{align}\begin{aligned}\theta_{u+1} = \theta_u - \eta_u \nabla_\theta \widetilde{\mathcal{L}}(\theta)\\\eta_u = \min\left(\eta_{\max}, \frac{1}{\lambda^1(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta_u))} \right)\end{aligned}\end{align} \]
Where \(\theta_u\) denotes the parameters at iteration \(u\), \(\widetilde{\mathcal{L}}(\theta)\) is the training loss, \(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)\) is the Hessian matrix of the training loss, \(\lambda^1(\cdot)\) is the top eigenvalue, and \(\eta_{\max}\) is a hyperparameter.
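For context, on a quadratic model plain gradient descent with learning rate \(\eta\) is stable only while the top Hessian eigenvalue satisfies
\[\lambda^1(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)) \leq \frac{2}{\eta}\]
which is the edge-of-stability threshold identified in Cohen et al. 2021. Choosing \(\eta_u \leq 1/\lambda^1\) therefore keeps each step well inside the stable regime while staying close to the underlying gradient flow.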
- __init__(params: Tensor, loss: LossFunction, lr: float, warmup_iters: int = 0, warmup_factor: float = 1.0)¶
- Args:
params (torch.Tensor): The parameters of the network that are being optimized.
loss (LossFunction): The loss function.
lr (float): Maximum learning rate \(\eta_{\max}\). The adaptive learning rate is truncated so that it never exceeds this value.
warmup_iters (int): If set, the maximum learning rate is initially set to a lower value for this many iterations, to damp out initial transients.
warmup_factor (float): If set, the maximum learning rate is initially reduced by this factor, to damp out initial transients.
- step() → AdaptiveGradientDescentReport¶
Takes a single step, returning a report.
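A sketch of running the adaptive solver with a warmup, again reusing the loss_func from above; reading loss off the returned report mirrors the introductory example and is otherwise an assumption about AdaptiveGradientDescentReport:

solver = stable.AdaptiveGradientDescent(
    params=loss_func.initialize_parameters(),
    loss=loss_func,
    lr=0.05,             # eta_max: upper bound on the adaptive learning rate
    warmup_iters=100,    # use a reduced bound for the first 100 iterations
    warmup_factor=10.0,  # the bound is reduced by this factor during warmup
)
loss = float('inf')
while loss > 0.1:
    loss = solver.step().loss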
- class stable_solvers.ExponentialEulerSolver(params: Tensor, loss: LossFunction, max_step_size: float, stiff_dim: int, warmup_iters: int = 0, warmup_factor: float = 1.0)¶
Uses the exponential Euler method from Lowell and Kastner 2024:
\[ \begin{align}\begin{aligned}\theta_{u+1} = \theta_u - \sum_{m=1}^k c_u^m r_u^m v^m(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)) - \eta_u w_u\\r_u^m = \min\left(\frac{1}{\lambda^m(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta))} \left(e^{-\lambda^m(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)) t} - 1\right), \eta_{\max}\right)\\c_u^m = \nabla_\theta \widetilde{\mathcal{L}}(\theta_u) \cdot v^m(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta))\\w_u = \nabla_\theta \widetilde{\mathcal{L}}(\theta_u) - \sum_{m=1}^k c_u^m v^m(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta))\\\eta_u = \max\left(\frac{1}{\lambda^{k+1}(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta))}, \eta_{\max}\right)\end{aligned}\end{align} \]
Where \(\theta_u\) denotes the parameters at iteration \(u\), \(\widetilde{\mathcal{L}}(\theta)\) is the training loss, \(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)\) is the Hessian matrix of the training loss, \(\lambda^m(\cdot)\) is the \(m\)th top eigenvalue, \(v^m(\cdot)\) is the \(m\)th top eigenvector, \(k\) is the expected dimension of the stiff subspace, and \(\eta_{\max}\) is a hyperparameter.
Essentially, where conventional gradient descent approximates the loss function as a locally linear function, the exponential Euler solver approximates it as a locally quadratic function. Since recovering all of the quadratic terms in the Taylor expansion requires computing all of the eigenvalues and eigenvectors of \(\mathcal{H}_\theta \widetilde{\mathcal{L}}(\theta)\), which is intractable, it instead captures only the largest and most influential quadratic components, and settles for a linear approximation in the other directions. It additionally calculates a learning rate by using one more eigenvalue, analogous to the AdaptiveGradientDescent solver, and flows along the gradient flow for a time step equal to that learning rate.
The stiff dimension should be set to the dimension of the highly curved subspace. In practice, this is equal to the dimension of the network outputs, reduced by one if using the cross-entropy loss. For example, a classifier trained using cross-entropy loss on a dataset with 10 classes would have a stiff dimension of 9. A regression network trained using mean-squared error to predict a single value would have a stiff dimension of 1.
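For instance, assuming loss_func wraps a 10-class classifier trained with cross-entropy loss, the stiff dimension would be 9:

# 10 network outputs, reduced by one because of the cross-entropy loss.
solver = stable.ExponentialEulerSolver(
    params=loss_func.initialize_parameters(),
    loss=loss_func,
    max_step_size=0.01,
    stiff_dim=9,
)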
- __init__(params: Tensor, loss: LossFunction, max_step_size: float, stiff_dim: int, warmup_iters: int = 0, warmup_factor: float = 1.0)¶
- Args:
params (torch.Tensor): The parameters of the network that are being optimized.
loss (LossFunction): The loss function.
max_step_size (float): Maximum step size.
stiff_dim (int): Dimension of the expected “stiff” component of the loss landscape, generally equal to the number of network outputs.
warmup_iters (int): If set, the maximum step size is initially set to a lower value for this many iterations, to damp out initial transients.
warmup_factor (float): If set, the maximum step size is initially reduced by this factor, to damp out initial transients.
- step() → ExponentialEulerSolverReport¶
Takes a single step, returning a report.