Hugging Face transformers.optimization
中的 get_scheduler
函数
get_scheduler
是 Hugging Face transformers.optimization
模块中的一个 学习率调度器(LR Scheduler)获取函数,用于 根据训练步数和策略调整学习率,适用于 Transformer 模型的优化。
1. get_scheduler
的基本用法
1.1. 导入 get_scheduler
from transformers import get_scheduler
1.2. 主要参数
scheduler = get_scheduler(
name="linear", # 调度器类型
optimizer=optimizer, # 绑定优化器
num_warmup_steps=0, # 预热步数
num_training_steps=1000 # 训练总步数
)
参数 | 作用 | 示例 |
---|---|---|
name | 调度器类型 | "linear" |
optimizer | 训练使用的优化器 | AdamW |
num_warmup_steps | 预热步数 | 0 |
num_training_steps | 训练总步数 | 1000 |
2. get_scheduler
支持的调度器类型
调度器名称 | 作用 | 适用任务 |
---|---|---|
"linear" | 线性衰减 | 通用 |
"cosine" | 余弦退火 | 长期训练 |
"cosine_with_restarts" | 余弦退火(带重启) | 阶段性学习率调整 |
"polynomial" | 多项式衰减 | 自定义学习率曲线 |
"constant" | 固定学习率 | 调试阶段 |
"constant_with_warmup" | 预热后固定学习率 | 微调预训练模型 |
3. get_scheduler
详细示例
3.1. 线性衰减 (linear
)
适用于 常规 NLP 训练,学习率随时间线性下降:
from transformers import get_scheduler
import torch
# 定义优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# 训练步数
num_training_steps = 1000
num_warmup_steps = 100
# 获取调度器
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
# 模拟学习率变化
for step in range(10):
optimizer.step()
scheduler.step()
print(f"Step {step+1}: Learning Rate = {scheduler.get_last_lr()[0]}")
3.2. 余弦衰减 (cosine
)
适用于 长期训练任务,学习率随着训练进程 逐渐减小:
scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=100, num_training_steps=1000)
3.3. 余弦衰减(带重启) (cosine_with_restarts
)
适用于 阶段性训练(如多轮训练),学习率 会周期性重置:
scheduler = get_scheduler("cosine_with_restarts", optimizer, num_warmup_steps=100, num_training_steps=1000)
3.4. 预热后固定学习率 (constant_with_warmup
)
适用于 微调预训练模型,前几步 缓慢增加学习率,然后保持恒定:
scheduler = get_scheduler("constant_with_warmup", optimizer, num_warmup_steps=100, num_training_steps=1000)
4. 在 Trainer
训练中使用 get_scheduler
在 Hugging Face Trainer
训练时,可以自定义 get_scheduler
:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="./results",
learning_rate=5e-5,
lr_scheduler_type="linear", # 这里指定调度器
warmup_steps=100,
num_train_epochs=3,
)
5. 总结
get_scheduler
是 transformers.optimization
中的 学习率调度器 获取函数,适用于 Transformer 模型训练。
常见调度器:
"linear"
:线性衰减(推荐默认)"cosine"
:余弦退火"cosine_with_restarts"
:周期性重置学习率"constant_with_warmup"
:预热后固定"polynomial"
:自定义多项式衰减
如果你的训练 需要调整学习率随训练进程的变化方式,推荐使用 get_scheduler
。