【Hugging Face】PEFT 库中的 PromptTuningConfig 常用的参数和说明

在 Hugging Face 的 PEFT(Parameter-Efficient Fine-Tuning)库中,PromptTuningConfig 是用于配置 Prompt Tuning 方法的类。Prompt Tuning 是一种参数高效的微调技术,通过在输入序列前添加可训练的虚拟 token(soft prompts)来调整预训练模型的行为,而不修改模型的原始权重。这种方法特别适合快速实验和资源受限的场景,因为它只需要优化少量的额外参数。以下是对 PromptTuningConfig 的详细讲解,包括其参数、用法、代码示例和注意事项。


1. PromptTuningConfig 概述

Prompt Tuning 的核心思想是为每个任务学习一组可训练的嵌入向量(虚拟 token),这些向量被添加到输入序列中,作为模型的“提示”。这些虚拟 token 的嵌入是可训练的,而预训练模型的权重保持冻结。相比于全参数微调或 LoRA,Prompt Tuning 的参数量极少,通常只有几千到几十万个参数,非常高效。

Prompt Tuning 的优势

  • 极低的参数量:仅训练少量的虚拟 token 嵌入。
  • 快速实验:适合快速验证任务性能。
  • 模块化:易于为不同任务切换不同的提示。
  • 兼容性:支持 Hugging Face 的 Transformers 模型(如 BERT、GPT、T5 等)。

2. PromptTuningConfig 的主要参数

以下是 PromptTuningConfig 中常用的参数及其说明:

  • peft_type(默认:"PROMPT_TUNING"

    • 指定 PEFT 方法类型,固定为 "PROMPT_TUNING"
    • 通常无需手动设置,由类自动处理。
  • task_type(可选,字符串)

    • 指定任务类型,帮助 PEFT 确定模型结构和用途。常见选项包括:
      • "SEQ_CLS":序列分类(如情感分析)。
      • "TOKEN_CLS":token 分类(如 NER)。
      • "CAUSAL_LM":因果语言模型(如 GPT)。
      • "SEQ_2_SEQ_LM":序列到序列模型(如 T5)。
    • 如果不明确任务类型,可以不设置。
  • num_virtual_tokens(整数,默认:20)

    • 指定添加到输入序列中的虚拟 token 数量。
    • 典型值:10 到 100。更多的 token 增加表达能力,但也增加参数量。
    • 参数量计算:num_virtual_tokens * embedding_dim(embedding_dim 通常是模型的隐藏层大小,如 BERT 的 768)。
  • prompt_tuning_init(字符串,默认:"TEXT"

    • 指定虚拟 token 嵌入的初始化方式。选项包括:
      • "TEXT":基于文本初始化(需要提供 prompt_tuning_init_texttokenizer_name_or_path)。
      • "RANDOM":随机初始化(使用高斯分布)。
    • 文本初始化可以提高性能,但需要选择与任务相关的文本。
  • prompt_tuning_init_text(字符串,可选)

    • prompt_tuning_init="TEXT" 时,指定用于初始化的文本。
    • 示例:"Classify the sentiment of this sentence."
    • 文本会被分词并转换为嵌入,用于初始化虚拟 token。
  • tokenizer_name_or_path(字符串,可选)

    • prompt_tuning_init="TEXT" 时,指定分词器的名称或路径。
    • 示例:"bert-base-uncased" 或本地分词器路径。
    • 如果不提供,PEFT 会尝试使用模型的分词器。
  • num_layers(整数,可选)

    • 指定在多少层中使用虚拟 token(仅对某些模型有效,如 GPT)。
    • 默认:所有层。
  • modules_to_save(列表,可选)

    • 指定需要全量微调的模块(不使用 Prompt Tuning)。例如,分类头的参数可以放入此列表。
    • 示例:["classifier"]

3. 使用 PromptTuningConfig 的基本流程

以下是一个使用 PromptTuningConfig 微调 BERT 模型(用于序列分类任务)的完整示例:

步骤 1:安装和导入库
pip install peft transformers torch datasets
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from peft import PromptTuningConfig, get_peft_model
from datasets import load_dataset
步骤 2:加载模型和数据集
# 加载预训练模型和分词器
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 加载数据集(以 GLUE 的 MRPC 为例)
dataset = load_dataset("glue", "mrpc")

# 数据预处理
def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(tokenize_function, batched=True)
步骤 3:配置 PromptTuningConfig
# 配置 Prompt Tuning
prompt_config = PromptTuningConfig(
    task_type="SEQ_CLS",                     # 序列分类任务
    num_virtual_tokens=20,                   # 虚拟 token 数量
    prompt_tuning_init="TEXT",               # 使用文本初始化
    prompt_tuning_init_text="Classify the sentiment of this sentence.",  # 初始化文本
    tokenizer_name_or_path=model_name,       # 分词器
    modules_to_save=["classifier"]           # 全量微调分类头
)

# 将 Prompt Tuning 应用到模型
peft_model = get_peft_model(model, prompt_config)

# 查看可训练参数
peft_model.print_trainable_parameters()

输出示例:

trainable params: 15,360 || all params: 109,499,138 || trainable%: 0.014

这表明只有约 0.014% 的参数需要训练(20 个虚拟 token × 768 维嵌入 + 分类头参数)。

步骤 4:训练模型
# 配置训练参数
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=1e-3,                     # Prompt Tuning 通常需要较高的学习率
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

# 初始化 Trainer
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

# 开始训练
trainer.train()
步骤 5:保存和加载 Prompt Tuning 模型
# 保存 Prompt Tuning 参数
peft_model.save_pretrained("./prompt_model")

# 加载 Prompt Tuning 模型
from peft import PeftModel
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
loaded_model = PeftModel.from_pretrained(base_model, "./prompt_model")
步骤 6:推理
# 准备输入
inputs = tokenizer("This is a positive sentence.", return_tensors="pt")

# 推理
loaded_model.eval()
outputs = loaded_model(**inputs)
logits = outputs.logits
print(logits)

4. PromptTuningConfig 的优化建议

  • 调整 num_virtual_tokens

    • 较小的值(如 10-20)适合简单任务或快速实验。
    • 较大的值(如 50-100)适合复杂任务,但会增加参数量。
    • 参数量 = num_virtual_tokens * embedding_dim,确保与硬件资源匹配。
  • 选择初始化方式

    • 文本初始化prompt_tuning_init="TEXT"):
      • 选择与任务相关的文本(如分类任务用“Classify this sentence.”)。
      • 确保 prompt_tuning_init_text 的长度足够(会被截断到 num_virtual_tokens)。
    • 随机初始化prompt_tuning_init="RANDOM"):
      • 适合没有明确提示的场景,但可能需要更多训练轮次。
  • 学习率

    • Prompt Tuning 通常需要较高的学习率(如 1e-3 或 5e-4),因为只优化少量参数。
    • 使用学习率调度器(如线性衰减)以提高稳定性。
  • 任务类型

    • 确保 task_type 与模型和任务匹配。例如,GPT 使用 "CAUSAL_LM",BERT 使用 "SEQ_CLS"
  • 内存优化

    • Prompt Tuning 本身内存需求极低,但如果模型较大,可以结合 4-bit 量化:
      from transformers import BitsAndBytesConfig
      quantization_config = BitsAndBytesConfig(load_in_4bit=True)
      model = AutoModelForSequenceClassification.from_pretrained(model_name, quantization_config=quantization_config)
      

5. Prompt Tuning vs. 其他 PEFT 方法

  • Prompt Tuning vs. LoRA

    • 参数量:Prompt Tuning 参数量更少(仅虚拟 token 嵌入),LoRA 涉及低秩矩阵。
    • 适用场景:Prompt Tuning 适合快速实验或小规模任务;LoRA 适合需要更高性能的复杂任务。
    • 灵活性:LoRA 可以应用于模型的多个模块,Prompt Tuning 仅调整输入。
  • Prompt Tuning vs. Prefix Tuning

    • 结构:Prompt Tuning 在输入序列添加虚拟 token,Prefix Tuning 在 Transformer 每层添加前缀向量。
    • 参数量:Prefix Tuning 通常参数量稍多,因为涉及多层。
    • 适用模型:Prompt Tuning 更通用,Prefix Tuning 更适合生成任务(如语言模型)。

6. 常见问题与解答

  • Q1:如何选择 num_virtual_tokens

    • 从 10-20 开始实验。如果性能不足,逐渐增加到 50 或 100。
    • 注意:过多的 token 可能导致过拟合或内存问题。
  • Q2:文本初始化效果不好怎么办?

    • 尝试更贴合任务的 prompt_tuning_init_text
    • 切换到随机初始化("RANDOM")并增加训练轮次。
    • 检查分词器是否与模型匹配。
  • Q3:Prompt Tuning 是否支持所有模型?

    • 是的,支持所有 Hugging Face Transformers 模型,但效果因模型和任务而异(生成模型如 GPT 通常效果更好)。
  • Q4:性能不佳如何优化?

    • 增加 num_virtual_tokens 或调整学习率。
    • 检查数据集质量和预处理步骤。
    • 尝试其他 PEFT 方法(如 LoRA)以比较性能。

7. 进阶用法

  • 多任务 Prompt Tuning

    • 为不同任务创建多个提示:
    prompt_config_task1 = PromptTuningConfig(
        task_type="SEQ_CLS",
        num_virtual_tokens=20,
        prompt_tuning_init="TEXT",
        prompt_tuning_init_text="Task 1 prompt."
    )
    prompt_config_task2 = PromptTuningConfig(
        task_type="SEQ_CLS",
        num_virtual_tokens=20,
        prompt_tuning_init="TEXT",
        prompt_tuning_init_text="Task 2 prompt."
    )
    
    peft_model = get_peft_model(model, prompt_config_task1, adapter_name="task1")
    peft_model.add_adapter("task2", prompt_config_task2)
    peft_model.set_adapter("task1")  # 切换提示
    
  • 保存到 Hugging Face Hub

    peft_model.push_to_hub("your-username/prompt-model")
    
  • 结合量化

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
    peft_model = get_peft_model(model, prompt_config)
    

8. 进一步资源

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值