在 Hugging Face 的 PEFT(Parameter-Efficient Fine-Tuning)库中,PromptTuningConfig
是用于配置 Prompt Tuning 方法的类。Prompt Tuning 是一种参数高效的微调技术,通过在输入序列前添加可训练的虚拟 token(soft prompts)来调整预训练模型的行为,而不修改模型的原始权重。这种方法特别适合快速实验和资源受限的场景,因为它只需要优化少量的额外参数。以下是对 PromptTuningConfig
的详细讲解,包括其参数、用法、代码示例和注意事项。
1. PromptTuningConfig 概述
Prompt Tuning 的核心思想是为每个任务学习一组可训练的嵌入向量(虚拟 token),这些向量被添加到输入序列中,作为模型的“提示”。这些虚拟 token 的嵌入是可训练的,而预训练模型的权重保持冻结。相比于全参数微调或 LoRA,Prompt Tuning 的参数量极少,通常只有几千到几十万个参数,非常高效。
Prompt Tuning 的优势:
- 极低的参数量:仅训练少量的虚拟 token 嵌入。
- 快速实验:适合快速验证任务性能。
- 模块化:易于为不同任务切换不同的提示。
- 兼容性:支持 Hugging Face 的 Transformers 模型(如 BERT、GPT、T5 等)。
2. PromptTuningConfig 的主要参数
以下是 PromptTuningConfig
中常用的参数及其说明:
-
peft_type
(默认:"PROMPT_TUNING"
)- 指定 PEFT 方法类型,固定为
"PROMPT_TUNING"
。 - 通常无需手动设置,由类自动处理。
- 指定 PEFT 方法类型,固定为
-
task_type
(可选,字符串)- 指定任务类型,帮助 PEFT 确定模型结构和用途。常见选项包括:
"SEQ_CLS"
:序列分类(如情感分析)。"TOKEN_CLS"
:token 分类(如 NER)。"CAUSAL_LM"
:因果语言模型(如 GPT)。"SEQ_2_SEQ_LM"
:序列到序列模型(如 T5)。
- 如果不明确任务类型,可以不设置。
- 指定任务类型,帮助 PEFT 确定模型结构和用途。常见选项包括:
-
num_virtual_tokens
(整数,默认:20)- 指定添加到输入序列中的虚拟 token 数量。
- 典型值:10 到 100。更多的 token 增加表达能力,但也增加参数量。
- 参数量计算:
num_virtual_tokens * embedding_dim
(embedding_dim 通常是模型的隐藏层大小,如 BERT 的 768)。
-
prompt_tuning_init
(字符串,默认:"TEXT"
)- 指定虚拟 token 嵌入的初始化方式。选项包括:
"TEXT"
:基于文本初始化(需要提供prompt_tuning_init_text
和tokenizer_name_or_path
)。"RANDOM"
:随机初始化(使用高斯分布)。
- 文本初始化可以提高性能,但需要选择与任务相关的文本。
- 指定虚拟 token 嵌入的初始化方式。选项包括:
-
prompt_tuning_init_text
(字符串,可选)- 当
prompt_tuning_init="TEXT"
时,指定用于初始化的文本。 - 示例:
"Classify the sentiment of this sentence."
。 - 文本会被分词并转换为嵌入,用于初始化虚拟 token。
- 当
-
tokenizer_name_or_path
(字符串,可选)- 当
prompt_tuning_init="TEXT"
时,指定分词器的名称或路径。 - 示例:
"bert-base-uncased"
或本地分词器路径。 - 如果不提供,PEFT 会尝试使用模型的分词器。
- 当
-
num_layers
(整数,可选)- 指定在多少层中使用虚拟 token(仅对某些模型有效,如 GPT)。
- 默认:所有层。
-
modules_to_save
(列表,可选)- 指定需要全量微调的模块(不使用 Prompt Tuning)。例如,分类头的参数可以放入此列表。
- 示例:
["classifier"]
。
3. 使用 PromptTuningConfig 的基本流程
以下是一个使用 PromptTuningConfig
微调 BERT 模型(用于序列分类任务)的完整示例:
步骤 1:安装和导入库
pip install peft transformers torch datasets
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from peft import PromptTuningConfig, get_peft_model
from datasets import load_dataset
步骤 2:加载模型和数据集
# 加载预训练模型和分词器
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 加载数据集(以 GLUE 的 MRPC 为例)
dataset = load_dataset("glue", "mrpc")
# 数据预处理
def tokenize_function(examples):
return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
步骤 3:配置 PromptTuningConfig
# 配置 Prompt Tuning
prompt_config = PromptTuningConfig(
task_type="SEQ_CLS", # 序列分类任务
num_virtual_tokens=20, # 虚拟 token 数量
prompt_tuning_init="TEXT", # 使用文本初始化
prompt_tuning_init_text="Classify the sentiment of this sentence.", # 初始化文本
tokenizer_name_or_path=model_name, # 分词器
modules_to_save=["classifier"] # 全量微调分类头
)
# 将 Prompt Tuning 应用到模型
peft_model = get_peft_model(model, prompt_config)
# 查看可训练参数
peft_model.print_trainable_parameters()
输出示例:
trainable params: 15,360 || all params: 109,499,138 || trainable%: 0.014
这表明只有约 0.014% 的参数需要训练(20 个虚拟 token × 768 维嵌入 + 分类头参数)。
步骤 4:训练模型
# 配置训练参数
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=1e-3, # Prompt Tuning 通常需要较高的学习率
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
)
# 初始化 Trainer
trainer = Trainer(
model=peft_model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
)
# 开始训练
trainer.train()
步骤 5:保存和加载 Prompt Tuning 模型
# 保存 Prompt Tuning 参数
peft_model.save_pretrained("./prompt_model")
# 加载 Prompt Tuning 模型
from peft import PeftModel
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
loaded_model = PeftModel.from_pretrained(base_model, "./prompt_model")
步骤 6:推理
# 准备输入
inputs = tokenizer("This is a positive sentence.", return_tensors="pt")
# 推理
loaded_model.eval()
outputs = loaded_model(**inputs)
logits = outputs.logits
print(logits)
4. PromptTuningConfig 的优化建议
-
调整
num_virtual_tokens
:- 较小的值(如 10-20)适合简单任务或快速实验。
- 较大的值(如 50-100)适合复杂任务,但会增加参数量。
- 参数量 =
num_virtual_tokens * embedding_dim
,确保与硬件资源匹配。
-
选择初始化方式:
- 文本初始化(
prompt_tuning_init="TEXT"
):- 选择与任务相关的文本(如分类任务用“Classify this sentence.”)。
- 确保
prompt_tuning_init_text
的长度足够(会被截断到num_virtual_tokens
)。
- 随机初始化(
prompt_tuning_init="RANDOM"
):- 适合没有明确提示的场景,但可能需要更多训练轮次。
- 文本初始化(
-
学习率:
- Prompt Tuning 通常需要较高的学习率(如 1e-3 或 5e-4),因为只优化少量参数。
- 使用学习率调度器(如线性衰减)以提高稳定性。
-
任务类型:
- 确保
task_type
与模型和任务匹配。例如,GPT 使用"CAUSAL_LM"
,BERT 使用"SEQ_CLS"
。
- 确保
-
内存优化:
- Prompt Tuning 本身内存需求极低,但如果模型较大,可以结合 4-bit 量化:
from transformers import BitsAndBytesConfig quantization_config = BitsAndBytesConfig(load_in_4bit=True) model = AutoModelForSequenceClassification.from_pretrained(model_name, quantization_config=quantization_config)
- Prompt Tuning 本身内存需求极低,但如果模型较大,可以结合 4-bit 量化:
5. Prompt Tuning vs. 其他 PEFT 方法
-
Prompt Tuning vs. LoRA:
- 参数量:Prompt Tuning 参数量更少(仅虚拟 token 嵌入),LoRA 涉及低秩矩阵。
- 适用场景:Prompt Tuning 适合快速实验或小规模任务;LoRA 适合需要更高性能的复杂任务。
- 灵活性:LoRA 可以应用于模型的多个模块,Prompt Tuning 仅调整输入。
-
Prompt Tuning vs. Prefix Tuning:
- 结构:Prompt Tuning 在输入序列添加虚拟 token,Prefix Tuning 在 Transformer 每层添加前缀向量。
- 参数量:Prefix Tuning 通常参数量稍多,因为涉及多层。
- 适用模型:Prompt Tuning 更通用,Prefix Tuning 更适合生成任务(如语言模型)。
6. 常见问题与解答
-
Q1:如何选择
num_virtual_tokens
?- 从 10-20 开始实验。如果性能不足,逐渐增加到 50 或 100。
- 注意:过多的 token 可能导致过拟合或内存问题。
-
Q2:文本初始化效果不好怎么办?
- 尝试更贴合任务的
prompt_tuning_init_text
。 - 切换到随机初始化(
"RANDOM"
)并增加训练轮次。 - 检查分词器是否与模型匹配。
- 尝试更贴合任务的
-
Q3:Prompt Tuning 是否支持所有模型?
- 是的,支持所有 Hugging Face Transformers 模型,但效果因模型和任务而异(生成模型如 GPT 通常效果更好)。
-
Q4:性能不佳如何优化?
- 增加
num_virtual_tokens
或调整学习率。 - 检查数据集质量和预处理步骤。
- 尝试其他 PEFT 方法(如 LoRA)以比较性能。
- 增加
7. 进阶用法
-
多任务 Prompt Tuning:
- 为不同任务创建多个提示:
prompt_config_task1 = PromptTuningConfig( task_type="SEQ_CLS", num_virtual_tokens=20, prompt_tuning_init="TEXT", prompt_tuning_init_text="Task 1 prompt." ) prompt_config_task2 = PromptTuningConfig( task_type="SEQ_CLS", num_virtual_tokens=20, prompt_tuning_init="TEXT", prompt_tuning_init_text="Task 2 prompt." ) peft_model = get_peft_model(model, prompt_config_task1, adapter_name="task1") peft_model.add_adapter("task2", prompt_config_task2) peft_model.set_adapter("task1") # 切换提示
-
保存到 Hugging Face Hub:
peft_model.push_to_hub("your-username/prompt-model")
-
结合量化:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig quantization_config = BitsAndBytesConfig(load_in_4bit=True) model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config) peft_model = get_peft_model(model, prompt_config)
8. 进一步资源
- 官方文档:https://huggingface.co/docs/peft/conceptual_guides/prompt_tuning
- GitHub 示例:https://github.com/huggingface/peft/tree/main/examples
- 论文:Prompt Tuning 原始论文(https://arxiv.org/abs/2104.08691)。