22 Jul 2023
On the data side, data augmentation helps models generalize.
On the model side, weight decay can help us generalize.
To prevent overfitting we shouldn’t allow our models to get too complex (e.g. a high-degree polynomial that wiggles through every training point, like in the pic). Having fewer parameters can prevent your model from getting overly complex, but it’s a limiting strategy.
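As an illustration of that picture (not from the original post), here is a quick numpy sketch: a degree-9 polynomial fit to ten noisy points drives the training error to roughly zero by memorizing the noise, while a degree-1 fit cannot.

```python
import numpy as np

# A few noisy samples of a simple underlying curve
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=x.shape)

# A degree-1 fit is too simple; a degree-9 polynomial can pass through
# all 10 training points, i.e. it fits the noise as well as the signal.
simple = np.poly1d(np.polyfit(x, y, deg=1))
wiggly = np.poly1d(np.polyfit(x, y, deg=9))

print("train MSE, degree 1:", np.mean((simple(x) - y) ** 2))
print("train MSE, degree 9:", np.mean((wiggly(x) - y) ** 2))  # ~0 on the training points
```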
To penalize complexity we can:
→ add all our parameters (weights) to the loss function,
but since some of them are positive and some negative, they would cancel each other out, so instead we can
→ add the squares of all the parameters to the loss function,
but that penalty can dominate the loss, and the best model would then be the one with all parameters set to 0, so
→ we multiply the sum of squares by a small number called *weight decay*, or wd.
Our loss function will look as follows:
Loss = MSE(y_hat, y) + wd * sum(w^2)
y_hat = predicted or estimated value of the target variable
y = actual target value
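A minimal PyTorch sketch of this penalized loss (the model, x, y and the wd value here are just illustrative):

```python
import torch
import torch.nn.functional as F

def loss_with_wd(model, x, y, wd=0.1):
    # MSE(y_hat, y) + wd * sum(w^2), matching the formula above
    y_hat = model(x)
    l2 = sum((p ** 2).sum() for p in model.parameters())
    return F.mse_loss(y_hat, y) + wd * l2

# Tiny usage example with a toy linear model
model = torch.nn.Linear(3, 1)
x, y = torch.randn(16, 3), torch.randn(16, 1)
loss = loss_with_wd(model, x, y)
loss.backward()
```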
When we update weights using gradient descent, we do the following:
w(t) = w(t-1) - lr * dLoss / dw
lr = learning rate, a hyperparameter that determines the step size, i.e. the rate at which weights are updated during training
dLoss/dw = derivative of the loss function with respect to the weights; it tells us how the loss changes as the weights are modified
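Note that the derivative of wd * sum(w^2) with respect to each weight is 2 * wd * w, so every update step also shrinks the weights toward zero, which is where the name *weight decay* comes from. A rough manual-update sketch (the toy linear model and numbers are just for illustration):

```python
import torch

lr, wd = 0.01, 0.1
w = torch.randn(10, requires_grad=True)
x, y = torch.randn(16, 10), torch.randn(16)

# Loss = MSE + wd * sum(w^2) for a toy linear model y_hat = x @ w
loss = ((x @ w - y) ** 2).mean() + wd * (w ** 2).sum()
loss.backward()

with torch.no_grad():
    # w.grad already contains the 2 * wd * w contribution from the penalty,
    # so this step both follows the data gradient and shrinks the weights
    w -= lr * w.grad
    w.grad.zero_()
```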
Generally wd = 0.1 works pretty well.
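In practice you usually don't add the penalty to the loss yourself; for example, PyTorch optimizers take a weight_decay argument and apply the shrinkage inside the update step (their convention adds wd * w rather than 2 * wd * w to the gradient). A rough sketch with an illustrative model:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)  # illustrative model
opt = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = F.mse_loss(model(x), y)  # plain MSE, no penalty term in the loss
loss.backward()
opt.step()        # the optimizer applies the weight-decay shrinkage itself
opt.zero_grad()
```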
→ Too much weight decay, and no matter how much you train, the model will never fit quite well.
→ Too little weight decay, and you can still train well, but you need to stop a bit early.