adamw 优化器
时间: 2025-05-24 22:14:48 AIGC 浏览: 50
### AdamW 优化器的原理与实现
AdamW 是一种改进版的 Adam 优化算法,它通过将权重衰减(Weight Decay)独立于学习率的方式引入到梯度下降过程中。这种方法解决了原始 Adam 中权重衰减与自适应学习率之间的冲突问题[^7]。
#### AdamW 的核心公式
AdamW 的更新规则可以分为以下几个部分:
1. **动量计算**
动量项用于平滑历史梯度的变化趋势,具体公式为:
\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\nabla L(\theta)
\]
其中 \(m_t\) 表示第 t 步的动量估计值,\(\beta_1\) 是一阶矩估计的指数衰减速率[\^8]。
2. **RMS 计算**
RMS(Root Mean Square)用来跟踪过去梯度平方的加权平均值,其公式为:
\[
v_t = \beta_2 v_{t-1} + (1-\beta_2)(\nabla L(\theta))^2
\]
这里 \(v_t\) 表示二阶矩估计值,\(\beta_2\) 则是二阶矩估计的指数衰减速率[\^9]。
3. **偏差修正**
对初始阶段的动量和 RMS 值进行偏差校正以消除初始化带来的影响:
\[
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}
\]
4. **参数更新**
结合以上步骤以及加入权重衰减后的最终参数更新方式为:
\[
\theta_{t+1} = \theta_t - \eta (\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}) - \lambda \cdot \theta_t
\]
上述公式中,\(\eta\) 是学习率,\(\epsilon\) 防止除零错误的小常数,而 \(\lambda\) 控制着权重衰减的程度[\^10]。
#### PyTorch 实现 AdamW
在实际操作层面,PyTorch 提供了一个内置版本的 AdamW 优化器,可以直接导入并使用:
```python
import torch.optim as optim
# 定义模型和其他超参...
model = YourModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
for epoch in range(num_epochs):
for inputs, labels in dataloader:
optimizer.zero_grad() # 清空之前的梯度
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失函数
loss.backward() # 反向传播获取梯度
optimizer.step() # 更新所有参数
```
---
### 使用注意事项
尽管 AdamW 性能优越,在某些特定条件下仍需要注意一些事项:
- 学习率的选择至关重要,过高的学习率可能导致训练不稳定甚至发散;反之则收敛速度较慢。
- 权重衰减值应依据具体任务调整,默认值可能并不适用于所有的网络结构或数据集特性。
- 如果遇到内存不足的情况,可尝试降低批量大小(batch size),或者采用梯度累积技术(grad accumulation)作为变通办法[\^11]。
---
阅读全文
相关推荐




















