22 Jul 2023
On the data side, data augmentation helps models generalize.
On the model side, weight decay can help us generalize.
To prevent overfitting we shouldn’t allow our models to get too complex (e.g. a high-degree polynomial that wiggles through every training point, like in the pic). Having fewer parameters can prevent your model from getting overly complex, but it’s a limiting strategy.
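As an illustration of that picture (not from the original post), here is a quick numpy sketch: a degree-9 polynomial fit to ten noisy points drives the training error to roughly zero by memorizing the noise, while a degree-1 fit cannot.

```python
import numpy as np

# A few noisy samples of a simple underlying curve
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=x.shape)

# A degree-1 fit is too simple; a degree-9 polynomial can pass through
# all 10 training points, i.e. it fits the noise as well as the signal.
simple = np.poly1d(np.polyfit(x, y, deg=1))
wiggly = np.poly1d(np.polyfit(x, y, deg=9))

print("train MSE, degree 1:", np.mean((simple(x) - y) ** 2))
print("train MSE, degree 9:", np.mean((wiggly(x) - y) ** 2))  # ~0 on the training points
```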
To penalize complexity we can:
→ add all our parameters (weights) to the loss function,
but since some of them are positive and some negative, they would cancel each other out, so instead we can
→ add the squares of all the parameters to the loss function,
but that penalty can dominate the loss, and the best model would then be the one with all parameters set to 0, so
→ we multiply the sum of squares by a small number called *weight decay*, or wd.
Our loss function will look as follows:
Loss = MSE(y_hat, y) + wd * sum(w^2)
y_hat = predicted or estimated value of the target variable
y = actual target value
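A minimal PyTorch sketch of this penalized loss (the model, x, y and the wd value here are just illustrative):

```python
import torch
import torch.nn.functional as F

def loss_with_wd(model, x, y, wd=0.1):
    # MSE(y_hat, y) + wd * sum(w^2), matching the formula above
    y_hat = model(x)
    l2 = sum((p ** 2).sum() for p in model.parameters())
    return F.mse_loss(y_hat, y) + wd * l2

# Tiny usage example with a toy linear model
model = torch.nn.Linear(3, 1)
x, y = torch.randn(16, 3), torch.randn(16, 1)
loss = loss_with_wd(model, x, y)
loss.backward()
```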
When we update weights using gradient descent, we do the following:
w(t) = w(t-1) - lr * dLoss / dw
lr = learning rate, a hyperparameter that determines the step size, i.e. the rate at which weights are updated during training
dLoss/dw = derivative of the loss function with respect to the weights; it tells us how the loss changes as the weights are modified
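Note that the derivative of wd * sum(w^2) with respect to each weight is 2 * wd * w, so every update step also shrinks the weights toward zero, which is where the name *weight decay* comes from. A rough manual-update sketch (the toy linear model and numbers are just for illustration):

```python
import torch

lr, wd = 0.01, 0.1
w = torch.randn(10, requires_grad=True)
x, y = torch.randn(16, 10), torch.randn(16)

# Loss = MSE + wd * sum(w^2) for a toy linear model y_hat = x @ w
loss = ((x @ w - y) ** 2).mean() + wd * (w ** 2).sum()
loss.backward()

with torch.no_grad():
    # w.grad already contains the 2 * wd * w contribution from the penalty,
    # so this step both follows the data gradient and shrinks the weights
    w -= lr * w.grad
    w.grad.zero_()
```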
Generally wd = 0.1 works pretty well.
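In practice you usually don't add the penalty to the loss yourself; for example, PyTorch optimizers take a weight_decay argument and apply the shrinkage inside the update step (their convention adds wd * w rather than 2 * wd * w to the gradient). A rough sketch with an illustrative model:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)  # illustrative model
opt = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = F.mse_loss(model(x), y)  # plain MSE, no penalty term in the loss
loss.backward()
opt.step()        # the optimizer applies the weight-decay shrinkage itself
opt.zero_grad()
```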
→ Too much weight decay, and no matter how much you train, the model will never fit quite well.
→ Too little weight decay, and you can still train well, but you need to stop a bit early.