732 字
4 分钟
次浏览
AdmaW(part I) Weight Decay == L2 Regularization?

0. 前言#

上一篇 Blog 中探讨了 L1 Regularization 和 L2 Regularization. 我们说到: 对损失函数添加 L2 Regularization , 最后对 w 使用梯度下降的时候, 实际是对 w 做了权重衰减.

然而, 上述等价性只在优化器为随机梯度下降(SGD)时成立(下边我们会证明). 在其他情况下, 特别是在训练深度学习模型时, 经常使用Adam优化器 , 上述结论不成立.

本篇 Blog 主要探讨在使用 Adam 的时候 Weight Decay 和 L2 Regularization 的关系, 以及当更新参数引入 momentum之后他们之间的关系 , 最后介绍 AdamW 优化器. 文中符号都尽量与 AdamW paper 中的一致.

NOTE

阅读前, 需要你 : 有高数基础知识, 线代基础知识, 当然还要有 ML和 DL 的知识背景.

1. SGD场景下#

1.1 无 momentum#

weight decay 的公式:

θt+1=(1λ)θtαft(θt)\theta_{t+1} = (1 - \lambda ) \theta_{t} - \alpha \nabla f_t(\theta_{t})

这里 α\alpha 是学习率 , λ\lambda 是 weight decay 的系数. 如果对损失函数施加 L2 Regularization :

ftreg(θ)=ft(θ)+λ2θ22f_t^{reg}(\theta) = f_t(\theta) + \frac {\lambda '} {2} {\|\theta\|_2}^2

使用梯度下降:

θt+1=θtαftreg(θt)=θtαft(θt)αλθt=(1αλ)θtαft(θt)\begin{align*} \theta_{t+1} &= \theta_{t} - \alpha \nabla f_t^{reg}(\theta_{t}) \\ &= \theta_{t} - \alpha \nabla f_t(\theta_{t}) - \alpha \lambda ' \theta_{t}\\ &= (1 - \alpha \lambda ' ) \theta_{t} - \alpha \nabla f_t(\theta_{t}) \end{align*}

如果想让 weight decay 和 带L2 Regularization 等价 , 则应有 αλ=λ\alpha \lambda' = \lambda , 显然对于SGD我们可以做到这个事情. 也就是说 在SGD优化器下, weight decay 和 带L2 Regularization 等价. 不过有个问题, 假设我们存在一个最优的weight decay系数 λ\lambda , 并且置了 L2 的系数 λ\lambda' , 这样就会把系统的学习率给固定了. 换句话说, 这时 weight decay 的系数 和 L2 Regularization 的系数是耦合的. 二者会相互影响.

1.2 添加 momentum#

如果在 L2 Regularization 的基础上添加 momentum 项

gt=ft1(θt1)+λθt1g_t = \nabla f_{t-1}(\theta_{t-1}) + \lambda ' \theta_{t-1}mt=β1mt1+gtm_t = \beta_{1}m_{t-1} + g_t

SGD with momentum and weight decay (L2 Regularization) 式子将会变为:

θt=θt1αmt=θt1α(β1mt1ft1(θt1)λθt1)=(1αλ)θt1weight decayαft1(θt1)gradient descentαβ1mt1momentum\begin{align*} \theta_{t} &= \theta_{t-1} - \alpha m_t \\ &= \theta_{t-1} - \alpha (\beta_{1}m_{t-1} - \nabla f_{t-1}(\theta_{t-1}) - \lambda ' \theta_{t-1}) \\ &= \underbrace{(1 - \alpha \lambda ' ) \theta_{t-1}}_{weight \ decay} - \underbrace{\alpha \nabla f_{t-1}(\theta_{t-1})}_{gradient \ descent} - \underbrace{\alpha \beta_{1}m_{t-1}}_{momentum} \end{align*}

这里, 学习率 α\alpha 和 L2 Regularization 的系数还是耦合, 并且还和 momentum 的系数也耦合上了.

WARNING

耦合归耦合, 但是该说不说, 在SGD场景下, Weight Decay == L2 Regularization 是可以成立的. 无论加不加 momentum

2. Adam场景下#

这里就不敲公式了,给出 AdamW paper 附录的证明.

image.png

我们知道, 在 Adam 优化器中, 学习率是自适应变化的, 上图中 MtM_t 就表示给学习率乘的自适应系数矩阵. 要想

λθt=αλMtθt\lambda \theta_{t} = \alpha \lambda ' M_t \theta_{t}

就必须让

λ=αλMt\lambda = \alpha \lambda ' M_t

其中 λ ,α ,λ\lambda \ , \alpha \ ,\lambda' 三兄弟都是常数, MtM_t 又是自适应系数, 显然是不能实现上边的目标的,

WARNING

因此对于类似 Adam 这种自适应学习率的算法, Weight Decay \neq L2 Regularization . 无论加不加 momentum

Reference#

AdmaW(part I) Weight Decay == L2 Regularization?
https://xuchenhui.cc/posts/2024-04-20-weight-decay-and-l2-regularization/
作者
CHENHUI
发布于
2024-04-20
许可协议
CC BY-NC-SA 4.0
📖 目录