We Don’t Need To Worry About Overfitting Anymore
Author(s): Sean Benhur J
Motivated by prior work connecting the geometry of the loss landscape and generalization, we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets [1].
Source: Sharpness-Aware Minimization paper [1]
In deep learning, we use optimization algorithms such as SGD or Adam to converge our model to a minimum, i.e., a point where the loss on the training dataset is low. But as research such as Zhang et al. has shown, many networks can easily memorize the training data and readily overfit. To prevent this problem and improve generalization, researchers at Google have published a paper called Sharpness-Aware Minimization, which provides state-of-the-art results on CIFAR-10 and other datasets.
In this article, we will look at why SAM achieves better generalization and how we can implement SAM in PyTorch.
Why does SAM work?
In gradient descent or any other optimization algorithm, our goal is to find parameters with a low loss value.
SAM achieves better generalization than standard optimization methods by seeking parameters that lie in neighborhoods having uniformly low loss, rather than parameters that merely have a low loss themselves.
Because SAM minimizes the loss over a whole neighborhood of parameters instead of at a single point, it converges to flatter regions of the loss landscape than other optimization methods, which in turn improves the generalization of the model.
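Concretely, the idea can be written as a min-max objective over a neighborhood of radius rho around the weights w. The sketch below is a simplified form of the paper's formulation (the weight-decay term is omitted), together with the first-order approximation of the worst-case perturbation that SAM actually computes:

```latex
% SAM objective: find weights whose entire rho-neighborhood has low training loss
\min_{w} \; \max_{\|\epsilon\|_2 \le \rho} L_{\text{train}}(w + \epsilon)

% First-order approximation of the worst-case perturbation used in practice
\hat{\epsilon}(w) \approx \rho \, \frac{\nabla_w L_{\text{train}}(w)}{\lVert \nabla_w L_{\text{train}}(w) \rVert_2}
```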
Note: SAM is not a new optimizer; it is used together with any common base optimizer such as SGD or Adam.
Implementing SAM in PyTorch:
Implementing SAM in PyTorch is simple and straightforward.
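Below is a minimal sketch of such a SAM wrapper. The method names first_step, second_step, and _grad_norm match the explanation that follows, while details such as the default rho=0.05 and the small numerical constant added for stability are assumptions.

```python
import torch


class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer (e.g. SGD)."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0, "rho should be non-negative"
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer performs the actual weight update in second_step.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Climb to the local "worst-case" point w + epsilon along the gradient direction.
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)                    # w <- w + epsilon
                self.state[p]["e_w"] = e_w     # remember epsilon so we can undo it
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Return to the original weights and apply the base optimizer update
        # using the gradient computed at w + epsilon.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])   # w <- w - epsilon (restore)
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        # Global L2 norm over all parameter gradients, used to scale epsilon.
        shared_device = self.param_groups[0]["params"][0].device
        return torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )
```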
Code explanation:
- First, we inherit from PyTorch's Optimizer class to create the optimizer. Although SAM is not a new optimizer, we need to inherit from that class so that we can update the gradients (with the help of the base optimizer) at each step.
- The class accepts the model parameters, a base optimizer, and rho, which is the size of the neighborhood over which the maximum loss is computed.
- Before moving on to the next steps, let's look at the pseudocode from the paper, which describes SAM without the math: compute the gradient of the batch loss at the current weights, use it to compute the perturbation epsilon, recompute the gradient at the perturbed weights, and update the original weights with that gradient.
- As we see in the pseudocode, after the first backward pass we compute epsilon and add it to the parameters; these steps are implemented in the method first_step in the Python code above.
- After the first step, we have to return to the original weights to perform the actual update of the base optimizer; these steps are implemented in the method second_step.
- The helper _grad_norm returns the norm of the stacked gradient vectors, which is used to scale the perturbation epsilon described in the pseudocode.
- After constructing this class, you can simply use it in your deep learning projects by following the snippet below in the training loop.
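A sketch of such a training loop is shown below; model, loader, and loss_fn are assumed to be defined elsewhere, and the hyperparameter values (rho, lr, momentum) are illustrative:

```python
from torch.optim import SGD

# Any base optimizer can be wrapped; SGD is used here as an example.
base_optimizer = SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)

for inputs, targets in loader:
    # First forward-backward pass: gradients at the current weights w.
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # Second forward-backward pass: gradients at the perturbed weights w + epsilon,
    # which the base optimizer then uses to update the original weights.
    loss_fn(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
```

Note that each iteration runs the forward and backward passes twice, once at w and once at w + epsilon.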
Finishing Thoughts:
Although SAM achieves better generalization, the main drawback of this method is that training takes roughly twice as long, since the forward and backward passes are computed twice to obtain the sharpness-aware gradient. Beyond that, SAM has also proven effective on the recently published NFNets, which are a current state of the art on ImageNet. In the future, we can expect more and more papers to use this technique to achieve better generalization.
If you’ve enjoyed this article or have any questions, please feel free to connect with me on LinkedIn.
References:
[1] Sharpness-Aware Minimization for Efficiently Improving Generalization
[2] Unofficial Implementation of SAM by Ryuichiro Hataya
Published via Towards AI