Hugging Face transformers
库中 model
的常用方法和属性
在 transformers
库中,model
代表 预训练的 Transformer 模型,可用于 文本分类、问答、文本生成等任务。不同任务的 model
可能会有不同的方法和属性,但它们共享许多常见功能。
1. model
的常见属性
在加载 AutoModel
或 AutoModelForXXX
后,可以使用以下属性:
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased")
属性 | 作用 | 示例 |
---|---|---|
model.config | 获取模型配置 | print(model.config) |
model.num_parameters() | 获取模型参数数量 | model.num_parameters() |
model.device | 查看模型在哪个设备上 | model.device |
model.dtype | 获取数据类型(fp32, fp16, bf16) | model.dtype |
model.embeddings | 获取模型的嵌入层 | model.embeddings |
model.encoder | 获取 Transformer 编码器 | model.encoder |
model.lm_head | 语言模型的输出层(如 GPT-2, T5) | model.lm_head |
2. model
的常用方法
方法 | 作用 |
---|---|
model.forward(input_ids, attention_mask) | 前向传播(通常直接调用 model(...) ) |
model.to(device) | 将模型移动到 CPU/GPU |
model.eval() | 进入评估模式 |
model.train() | 进入训练模式 |
model.generate(input_ids, max_length=50) | 生成文本(适用于 GPT-2, T5, BART) |
model.save_pretrained(path) | 保存模型 |
model.load_state_dict(torch.load(path)) | 加载训练好的参数 |
model.parameters() | 获取所有参数 |
model.named_parameters() | 获取所有参数及名称 |
3. model
详细用法
3.1. 前向传播
所有 Transformer 模型都支持 forward()
,但通常直接调用 model(...)
:
from transformers import AutoTokenizer, AutoModel
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
# 处理输入文本
text = "Hugging Face is great!"
inputs = tokenizer(text, return_tensors="pt")
# 前向传播
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
输出
torch.Size([1, 7, 768])
last_hidden_state.shape
解释:
1
:批量大小7
:序列长度(包括 [CLS] 和 [SEP])768
:隐藏层维度
3.2. 进入训练或评估模式
model.train() # 启用 dropout、BatchNorm
model.eval() # 关闭 dropout,进入推理模式
3.3. 将模型移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
3.4. 计算模型参数数量
num_params = model.num_parameters()
print(f"模型参数数量: {num_params / 1e6:.2f}M")
3.5. 生成文本(适用于 GPT-2, BART, T5)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_text = "Hugging Face is"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=50)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)
3.6. 保存和加载模型
保存模型
model.save_pretrained("./my_model")
加载模型
from transformers import AutoModel
model = AutoModel.from_pretrained("./my_model")
3.7. 加载自定义权重
如果你有一个训练好的 .bin
权重:
import torch
state_dict = torch.load("pytorch_model.bin")
model.load_state_dict(state_dict)
4. model
在不同任务中的应用
4.1. 文本分类
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
inputs = tokenizer("Hugging Face is great!", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax().item()
print(predicted_class)
4.2. 问答任务
from transformers import AutoModelForQuestionAnswering
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
question = "Where is Hugging Face based?"
context = "Hugging Face is based in New York City."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)
start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax() + 1
answer = tokenizer.decode(inputs.input_ids[0][start:end])
print(answer)
4.3. 生成式任务(翻译、摘要)
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
input_text = "Hugging Face provides NLP tools for AI applications."
inputs = tokenizer(input_text, return_tensors="pt")
output_ids = model.generate(inputs.input_ids, max_length=30)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)
5. 总结
model
在 transformers
中是核心组件,适用于 文本分类、问答、翻译、文本生成等任务。
常用方法:
model(...)
进行 前向传播model.generate()
进行 文本生成model.to(device)
移动模型到 GPUmodel.eval()
进入推理模式model.train()
进入训练模式model.save_pretrained(path)
保存模型model.num_parameters()
查看参数数量
不同任务使用不同的 AutoModelForXXX
:
- 文本分类:
AutoModelForSequenceClassification
- 问答:
AutoModelForQuestionAnswering
- 文本生成:
AutoModelForCausalLM
- 翻译/摘要:
AutoModelForSeq2SeqLM