Hugging Face transformers
库中 Trainer
常用方法和属性
Trainer
是 Hugging Face transformers
提供的 高层 API,用于 简化 PyTorch Transformer 模型的训练、评估和推理,支持 多 GPU 训练、梯度累积、混合精度训练 等。
1. Trainer
的常见属性
在 Trainer
初始化后,可以访问以下常见属性:
from transformers import Trainer
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
属性 | 作用 |
---|---|
trainer.model | 训练的 Transformer 模型 |
trainer.args | TrainingArguments 训练参数 |
trainer.train_dataset | 训练数据集 |
trainer.eval_dataset | 评估数据集 |
trainer.tokenizer | 使用的 tokenizer |
trainer.state | 训练状态,包括当前 epoch、global_step |
trainer.control | 训练控制参数 |
trainer.log_history | 训练日志历史 |
2. Trainer
的常用方法
方法 | 作用 |
---|---|
trainer.train() | 训练模型 |
trainer.evaluate() | 评估模型 |
trainer.predict(test_dataset) | 进行推理 |
trainer.save_model(path) | 保存模型 |
trainer.save_state() | 保存训练状态 |
trainer.load_state_dict(state_dict) | 加载训练状态 |
trainer.log(metrics) | 记录日志 |
trainer.add_callback(callback) | 添加训练回调 |
trainer.remove_callback(callback) | 移除训练回调 |
3. Trainer
详细用法
3.1. 训练模型
trainer.train()
- 训练日志默认每
logging_steps
步打印一次 - 可以在
TrainingArguments
中调整evaluation_strategy="epoch"
让训练每个 epoch 进行评估
3.2. 评估模型
trainer.evaluate()
返回 评估指标(如 loss
、accuracy
)。
3.3. 进行推理
predictions = trainer.predict(test_dataset)
print(predictions)
predictions.predictions
:模型输出的 logitspredictions.label_ids
:真实标签
3.4. 保存和加载模型
保存模型
trainer.save_model("./saved_model")
重新加载模型
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("./saved_model")
3.5. 记录日志
trainer.log({"loss": 0.1, "accuracy": 0.95})
3.6. 获取训练状态
print(trainer.state)
3.7. 添加回调
from transformers import EarlyStoppingCallback
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))
4. Trainer
在不同任务中的应用
4.1. 文本分类
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
logging_dir="./logs",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
4.2. 命名实体识别(NER)
from transformers import AutoModelForTokenClassification
model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", num_labels=9)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
4.3. 机器翻译
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
5. 总结
Trainer
是 transformers
提供的 高效训练工具,适用于 文本分类、NER、翻译、摘要等任务。
常用方法:
trainer.train()
进行 训练trainer.evaluate()
进行 评估trainer.predict(test_dataset)
进行 推理trainer.save_model(path)
保存模型trainer.log(metrics)
记录日志trainer.add_callback(callback)
添加回调