import torch
import json
import os
import argparse
import numpy as np
import re
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TrainingArguments,
    BitsAndBytesConfig,
    GenerationConfig,
    get_cosine_schedule_with_warmup
)
import torch.optim as optim
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, classification_report
import warnings
warnings.filterwarnings("ignore", message="Could not find a config file")
# Label-map definitions
TASK1_LABEL_MAP = {"无害": 0, "有害": 1}  # "无害" = harmless, "有害" = harmful
TASK2_LABEL_MAP = {
    0: "无害",        # harmless
    1: "针对性有害",  # targeted harm
    2: "一般性有害",  # general harm
    3: "性暗示",      # sexual innuendo
    4: "沮丧文化"     # "dispirited culture" (defeatist memes)
}
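
# The dataset below builds a full ChatML conversation per sample and
# supervises ONLY the label tokens that follow the "结论:" ("Conclusion:")
# marker; every other position in the labels tensor is set to -100 so the
# cross-entropy loss ignores it.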
class HarmfulMemeDataset(Dataset):
    def __init__(self, annotation_path, processor, task_name, image_root=None):
        print(f"Loading annotations from: {annotation_path}")
        self.image_root = image_root
        with open(annotation_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.task_name = task_name
        self.label_token_ids = self._get_label_tokens()

    def _get_label_tokens(self):
        """Map each label string to its token-ID sequence."""
        if self.task_name == "task_1":
            labels = ["无害", "有害"]
        else:
            labels = [str(i) for i in range(5)]  # classes 0-4
        token_ids = {}
        tokenizer = self.processor.tokenizer
        for label in labels:
            tokens = tokenizer.encode(label)
            token_ids[label] = tokens  # keep the full token sequence
            if len(tokens) > 1:
                print(f"Warning: label '{label}' is tokenized into multiple tokens: {tokens}")
        return token_ids

    def __len__(self):
        return len(self.data)
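
    # For reference, each annotation record is expected to look roughly like
    # this (field names taken from the accesses below; values hypothetical):
    # {"path": "memes/0001.jpg", "text": "caption on the meme",
    #  "label": "有害", "type": 2, "explanation": "free-text rationale"}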
    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.image_root:
            image_path = os.path.join(self.image_root, sample["path"])
        else:
            image_path = sample["path"]
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        image = Image.open(image_path).convert("RGB")
        if self.task_name == "task_1":
            user_content = (f"任务:有害模因检测(是否有害)。\n输入文本:'{sample['text']}'\n"
                            f"请判断此图文组合是否有害。")
            raw_label = sample['label']  # raw label is "无害"/"有害" or an int
            if isinstance(raw_label, int):
                # Integer label: map back to the string form via the inverted map
                label_map = {v: k for k, v in TASK1_LABEL_MAP.items()}
                label = label_map.get(raw_label, "无害")  # default: harmless
            else:
                label = raw_label
        else:
            user_content = (f"任务:有害模因类型分类。\n输入文本:'{sample['text']}'\n"
                            f"请判断此图文组合的有害类型(0-4)。")
            label = str(sample['type'])  # integer class id as a string
        assistant_content = f"结论:{label}。\n理由:{sample['explanation']}"
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": user_content}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_content}]}
        ]
        # The assistant turn is already in `messages`, so do not append an
        # extra generation prompt to a training example ("chat_format" is not
        # an apply_chat_template argument and has been dropped)
        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        # Tokenize the text alone (no image) to locate the label position;
        # this assumes the processor accepts text+images kwargs (a raw
        # tokenizer would not)
        encoding = self.processor(
            text=prompt,
            images=None,
            return_tensors="pt",
            padding=False,
            truncation=False
        )
        prompt_tokens = encoding["input_ids"][0].tolist()
        # Find where "结论:" ends; the label tokens follow it. Note this
        # subsequence match can fail if BPE merges differently in context.
        conclusion_start = self.processor.tokenizer.encode("结论:")
        start_idx = -1
        for i in range(len(prompt_tokens) - len(conclusion_start) + 1):
            if prompt_tokens[i:i + len(conclusion_start)] == conclusion_start:
                start_idx = i + len(conclusion_start)
                break
        inputs = self.processor(
            text=prompt,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512
        )
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        # Label tensor: supervise only the label tokens after "结论:"
        labels = torch.full_like(inputs["input_ids"], fill_value=-100, dtype=torch.long)
        if start_idx != -1 and start_idx < len(labels):
            label_tokens = self.label_token_ids[label]
            for i, token_id in enumerate(label_tokens):
                if start_idx + i < len(labels):
                    labels[start_idx + i] = token_id
        inputs["labels"] = labels
        return inputs

    def parse_generated_text(self, text):
        """Parse generated text and extract the conclusion label."""
        conclusion_match = re.search(r"结论[::]\s*(\S+)", text)
        if not conclusion_match:
            return None
        conclusion = conclusion_match.group(1).strip().rstrip('。.')
        if conclusion in ["无害", "有害"]:  # task-1 labels
            return conclusion
        elif conclusion.isdigit() and 0 <= int(conclusion) <= 4:  # task-2 labels
            return conclusion
        # Fall back to token-level matching for multi-token labels, reusing
        # the dataset's own tokenizer instead of reloading the processor
        tokenizer = self.processor.tokenizer
        conclusion_tokens = tokenizer.encode(conclusion, add_special_tokens=False)
        for label, tokens in self.label_token_ids.items():
            if conclusion_tokens == tokens:
                return label
        return None
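
# Illustrative parses (hypothetical model outputs):
#   dataset.parse_generated_text("结论:有害。\n理由:该图文...")  -> "有害"
#   dataset.parse_generated_text("结论:3。\n理由:...")           -> "3"
#   dataset.parse_generated_text("无法判断。")                     -> None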

def compute_metrics(task_name, preds, labels):
    """Compute evaluation metrics."""
    mask = labels != -100  # defensively drop any ignore-index entries
    preds = preds[mask]
    labels = labels[mask]
    if task_name == "task_1":
        # Binary classification
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1": f1_score(labels, preds, average="binary"),
            "precision": precision_score(labels, preds, average="binary"),
            "recall": recall_score(labels, preds, average="binary")
        }
    else:
        # Multi-class classification
        report = classification_report(labels, preds, output_dict=True, zero_division=0)
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1_macro": f1_score(labels, preds, average="macro"),
            "precision_macro": precision_score(labels, preds, average="macro"),
            "recall_macro": recall_score(labels, preds, average="macro"),
            "class_report": report
        }
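
# Example: compute_metrics("task_1", np.array([1, 0, 1]), np.array([1, 1, 1]))
# -> accuracy 0.667, precision 1.0, recall 0.667, f1 0.8 (illustrative values)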

def main(args):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # 1. Load the model and processor
    print("Loading model and processor...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=quantization_config,
        trust_remote_code=True,
        device_map="auto",
        bf16=True  # flag understood by Qwen's remote code
    )
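    # With load_in_4bit the frozen base weights are stored as NF4 while the
    # LoRA adapters added below train in bfloat16 (the usual QLoRA recipe);
    # only the adapter weights receive gradients and get saved.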
    model.generation_config = GenerationConfig.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        chat_format="chatml",
        max_new_tokens=100,
        pad_token_id=model.generation_config.eos_token_id
    )
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    # For Qwen-VL-Chat, AutoProcessor can resolve to the remote-code tokenizer
    # itself; alias it so the `.tokenizer` accesses below work either way
    if not hasattr(processor, "tokenizer"):
        processor.tokenizer = processor
    # ChatML-style template; message['content'] is a list of typed parts here,
    # so render only its text items rather than the raw Python list
    processor.chat_template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message['role'] }}\n"
        "{% for item in message['content'] %}"
        "{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}"
        "{% endfor %}<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    # Set a pad token if missing (Qwen's tokenizer defines none by default;
    # its remote code exposes <|endoftext|> as `eod_id`)
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.pad_token = processor.tokenizer.decode([processor.tokenizer.eod_id])
        print(f"pad_token set to: {processor.tokenizer.pad_token}")
    # 2. LoRA configuration
    print("Configuring LoRA...")
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        # Names must match the model's own layer names; in Qwen's remote code
        # the attention/MLP projections are c_attn, c_proj, w1 and w2
        target_modules=[
            "c_attn", "c_proj", "w1", "w2", "w3",
            "visual.proj", "visual.image_encoder"
        ]
    )
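    # Sanity check: target_modules entries that match nothing are silently
    # skipped by peft, so it is worth listing what will actually be wrapped:
    # for name, _ in model.named_modules():
    #     if any(name == t or name.endswith("." + t) for t in lora_config.target_modules):
    #         print(name)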
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    # 3. Initialize the optimizer (the LR scheduler is created further down,
    # once the total number of training steps is known)
    optimizer = optim.AdamW(
        peft_model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    # 4. Training arguments
    # NOTE: no HF Trainer is used below (the loop is manual), so these mostly
    # just record the hyper-parameters; only output_dir is actually read
    training_args = TrainingArguments(
        output_dir=os.path.join(args.output_dir, args.task),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type="cosine",
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="epoch",
        eval_accumulation_steps=1,
        metric_for_best_model="f1" if args.task == "task_1" else "f1_macro",
        greater_is_better=True,
        load_best_model_at_end=True,
        bf16=True,
        report_to="none",
        remove_unused_columns=False,
        disable_tqdm=False,
        skip_memory_metrics=True,
        dataloader_pin_memory=False,
    )
    # 5. Load the datasets
    print(f"Loading datasets for {args.task}...")
    train_dataset = HarmfulMemeDataset(
        annotation_path=args.train_annotation_path,
        processor=processor,
        task_name=args.task,
        image_root=args.image_root
    )
    test_dataset = HarmfulMemeDataset(
        annotation_path=args.test_annotation_path,
        processor=processor,
        task_name=args.task,
        image_root=args.image_root
    )
    # Build the dataloaders (samples are padded to a fixed max_length in
    # __getitem__, so the default collate_fn can stack them)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    )
    eval_loader = DataLoader(
        test_dataset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    # Compute total optimizer steps and set up the LR scheduler
    total_train_steps = len(train_loader) // args.grad_accum_steps * args.epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_train_steps
    )
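    # Illustrative step math (hypothetical numbers): with 1000 training
    # samples, batch_size=4 and grad_accum_steps=2, each epoch has 250
    # batches -> 125 optimizer steps, i.e. 625 scheduler steps over 5 epochs.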
    # 6. Training loop
    print(f"Starting {args.task} training...")
    best_metric = -1
    for epoch in range(args.epochs):
        print(f"\n===== Epoch {epoch + 1}/{args.epochs} =====")
        # Training phase
        peft_model.train()
        total_train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}", unit="batch")
        for step, batch in enumerate(train_pbar):
            batch = {k: v.to(peft_model.device) for k, v in batch.items()}
            # Forward pass
            outputs = peft_model(**batch)
            loss = outputs.loss
            total_train_loss += loss.item()
            # Gradient accumulation
            loss = loss / args.grad_accum_steps
            loss.backward()
            # Parameter update
            if (step + 1) % args.grad_accum_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            # Update the progress bar
            train_pbar.set_postfix({"loss": f"{loss.item() * args.grad_accum_steps:.4f}"})
        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch {epoch + 1} average training loss: {avg_train_loss:.4f}")
        # Evaluation phase
        peft_model.eval()
        all_preds = []
        all_labels = []
        all_generated_texts = []
        eval_pbar = tqdm(eval_loader, desc=f"Evaluating Epoch {epoch + 1}", unit="batch")
        with torch.no_grad():
            for batch in eval_pbar:
                # Recover the gold class id per sample by decoding the
                # supervised (non -100) label tokens back to text; rows with
                # no supervised tokens fall back to class 0
                label_ids = batch["labels"].cpu().numpy()
                gold_labels = []
                for row in label_ids:
                    gold_text = processor.tokenizer.decode(row[row != -100].tolist()).strip()
                    if args.task == "task_1":
                        gold_labels.append(TASK1_LABEL_MAP.get(gold_text, 0))
                    else:
                        gold_labels.append(int(gold_text) if gold_text.isdigit() else 0)
                # Generate text. NOTE: __getitem__ encodes the full
                # conversation (gold answer included); for a clean evaluation
                # the prompt should be rebuilt without the assistant turn.
                inputs = {k: v.to(peft_model.device) for k, v in batch.items() if k != "labels"}
                pad_id = processor.tokenizer.pad_token_id
                if pad_id is None:
                    pad_id = getattr(processor.tokenizer, "eod_id", processor.tokenizer.eos_token_id)
                generated_ids = peft_model.generate(
                    **inputs,
                    generation_config=model.generation_config,
                    pad_token_id=pad_id
                )
                # Decode the generated text
                generated_texts = processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )
                # Parse the generated text into predicted labels
                batch_preds = []
                for text in generated_texts:
                    # Keep only the assistant part of the response
                    if "<|im_start|>assistant" in text:
                        response = text.split("<|im_start|>assistant")[-1].strip()
                    else:
                        response = text
                    # Extract the conclusion (a dataset method, so call it there)
                    conclusion = test_dataset.parse_generated_text(response)
                    if conclusion is None:
                        # Unparseable output: fall back to the default label
                        # (a string in both tasks, so the checks below work)
                        pred_label = "无害" if args.task == "task_1" else "0"
                    else:
                        pred_label = conclusion
                    # Convert to a numeric label
                    if args.task == "task_1":
                        if "无害" in pred_label:
                            pred_value = 0
                        elif "有害" in pred_label:
                            pred_value = 1
                        else:
                            pred_value = 0  # unparseable, use the default
                    else:
                        if pred_label in ["0", "1", "2", "3", "4"]:
                            pred_value = int(pred_label)
                        else:
                            pred_value = 0  # unparseable, use the default
                    batch_preds.append(pred_value)
                all_preds.extend(batch_preds)
                all_labels.extend(gold_labels)
                all_generated_texts.extend(generated_texts)
        # Compute the evaluation metrics
        metrics = compute_metrics(args.task, np.array(all_preds), np.array(all_labels))
        # Print the evaluation results
        print("\nEvaluation metrics:")
        print("=" * 50)
        if args.task == "task_1":
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"F1 Score: {metrics['f1']:.4f}")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall: {metrics['recall']:.4f}")
        else:
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Macro F1: {metrics['f1_macro']:.4f}")
            print(f"Macro Precision: {metrics['precision_macro']:.4f}")
            print(f"Macro Recall: {metrics['recall_macro']:.4f}")
            print("\nClassification report:")
            # Pass `labels` explicitly so target_names still lines up when
            # some classes are missing from this evaluation split
            print(classification_report(
                all_labels, all_preds,
                labels=list(TASK2_LABEL_MAP.keys()),
                target_names=list(TASK2_LABEL_MAP.values()),
                zero_division=0
            ))
        print("=" * 50)
        # Save the best model
        current_metric = metrics["f1"] if args.task == "task_1" else metrics["f1_macro"]
        if current_metric > best_metric:
            best_metric = current_metric
            save_path = os.path.join(training_args.output_dir, f"best_model_epoch{epoch+1}")
            print(f"Saving best model (metric {current_metric:.4f}) to {save_path}")
            peft_model.save_pretrained(save_path)
            # Save a few generated samples for inspection
            sample_output_path = os.path.join(save_path, "sample_outputs.txt")
            with open(sample_output_path, "w", encoding="utf-8") as f:
                for i, text in enumerate(all_generated_texts[:10]):
                    f.write(f"Sample {i+1}:\n")
                    f.write(text)
                    f.write("\n" + "-" * 80 + "\n")
    print(f"Training finished! Best metric: {best_metric:.4f}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a harmful-meme detection model")
    parser.add_argument("--model_id", default="/xzwu/Qwen-VL-Chat", help="Pretrained model path")
    parser.add_argument("--output_dir", default="/xzwu/explain-m3-adapter", help="Output directory")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
    parser.add_argument("--eval_batch_size", type=int, default=4, help="Evaluation batch size")
    parser.add_argument("--grad_accum_steps", type=int, default=2, help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=100, help="Warmup steps")
    parser.add_argument("--lora_rank", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of data-loading workers")
    parser.add_argument("--task", choices=["task_1", "task_2"], default="task_1", help="Task type")
    parser.add_argument("--train_annotation_path", default="/xzwu/data/data/train_data_explanation.json", help="Training annotation path")
    parser.add_argument("--test_annotation_path", default="/xzwu/data/data/test_data_explanation.json", help="Test annotation path")
    parser.add_argument("--image_root", default="/xzwu/data/meme", help="Image root directory")
    args = parser.parse_args()
    # Print the configuration
    print("=" * 50)
    print("Training configuration:")
    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")
    print("=" * 50)
    main(args)

Running this script produced the following error:

Traceback (most recent call last):
  File "/xzwu/explain-m3/explain-m3-project/train2.py", line 513, in <module>
    main(args)
  File "/xzwu/explain-m3/explain-m3-project/train2.py", line 392, in main
    if hasattr(processor.text_processor.tokenizer, 'pad_token_id')
AttributeError: 'QWenTokenizer' object has no attribute 'text_processor'
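
The traceback points at the pad_token_id lookup inside the generate() call: for Qwen-VL-Chat, AutoProcessor.from_pretrained resolves to the model's remote-code QWenTokenizer, which has no text_processor attribute. A minimal sketch of the fix (the same form is applied in the evaluation loop above), assuming the tokenizer is reachable either directly or via processor.tokenizer, and that Qwen's tokenizer exposes <|endoftext|> as eod_id:

    # Resolve the tokenizer whether `processor` is a processor or a tokenizer
    tokenizer = getattr(processor, "tokenizer", processor)

    # Prefer an explicit pad id; fall back to end-of-text
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "eod_id", tokenizer.eos_token_id)

    generated_ids = peft_model.generate(
        **inputs,
        generation_config=model.generation_config,
        pad_token_id=pad_id
    )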