TRL - Transformer Reinforcement Learning: Datasets Passed to SFTTrainer
flyfish
Taking the JSON format as an example.
Step 1: A local JSON file
apple_data.json
Its contents are as follows:
[
{
"instruction": "介绍苹果的基本分类(按成熟季节)",
"input": "苹果的季节分类",
"output": "苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。"
},
{
"instruction": "用拟人化手法描述一个苹果",
"input": "一个红苹果",
"output": "我是一个圆滚滚的红苹果,穿着半红半黄的“外套”,表面光滑得像打了蜡。头顶的果柄是我的“小辫子”,轻轻一摸,还能感受到皮肤下饱满的汁水在悄悄晃动。凑近闻,一股清甜的香气会钻进你的鼻子——那是我在邀请你咬一口呀!"
},
{
"instruction": "解答关于苹果储存的问题",
"input": "如何延长苹果的保鲜期?",
"output": "延长苹果保鲜期的方法有:①冷藏储存:放入冰箱冷藏室(温度0-4℃),用保鲜袋包裹(留小孔透气),可保存1-2个月;②避免混放:与香蕉、梨等释放乙烯的水果分开,防止加速成熟;③阴凉通风处:常温下置于阴凉干燥处,可保存1-2周,适合短期食用。"
},
{
"instruction": "列举苹果的3种常见食用方法(非直接生食)",
"input": "苹果的加工吃法",
"output": "苹果的非生食做法有:①苹果派:将苹果切丁,加黄油、糖炒软,与面粉、鸡蛋等制作的派皮烘焙,外酥内软;②苹果银耳汤:苹果切块与银耳、红枣同煮,加冰糖调味,润肺养颜;③烤苹果:苹果去核,填入蜂蜜和坚果,烤箱烤制,果香浓郁,适合秋冬食用。"
},
{
"instruction": "对比苹果与梨的核心区别(从口感和营养侧重)",
"input": "苹果和梨的区别",
"output": "从口感看:苹果果肉多脆嫩(除面苹果外),果皮较薄;梨果肉偏细腻多汁,部分品种(如酥梨)果皮更光滑。从营养侧重看:苹果膳食纤维(尤其是果胶)含量更高,有助于肠道蠕动;梨的水分和梨醇含量更丰富,润肺生津效果更突出,适合干燥季节食用。"
}
]
Step 2: Load the local JSON file with load_dataset
When calling load_dataset, specify 'json' as the data format (since the source is a JSON file) and point the data_files parameter at the local file path.
from datasets import load_dataset
# Load the local JSON file
dataset = load_dataset(
    path='json',                  # data format: JSON
    data_files='apple_data.json'  # local JSON file path (use a path such as './data/apple_data.json' if the file is not in the current directory)
)
# Inspect the result
print(dataset)              # print the dataset structure
print("\nFirst example:")
print(dataset['train'][0])  # print the first example
Output
Running this returns a DatasetDict object with a single split named 'train' (the JSON file defines no splits, so load_dataset assigns all of the data to the 'train' split by default).
Generating train split: 5 examples [00:00, 766.53 examples/s]
DatasetDict({
train: Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 5
})
})
First example:
{'instruction': '介绍苹果的基本分类(按成熟季节)', 'input': '苹果的季节分类', 'output': '苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。'}
Notes
- Data format: the data is a JSON array ([{}, {}, ...]), a standard format supported by load_dataset; each object corresponds to one sample.
- Splits: if the data has several splits (e.g. train/test), store each split in its own JSON file (e.g. train.json, test.json) and pass data_files={'train': 'train.json', 'test': 'test.json'}.
- Downstream use: the loaded dataset['train'] can be used directly as training data for SFTTrainer, provided the fields match what the trainer expects. SFTTrainer may, for example, require a 'text' field, in which case instruction + input + output need to be concatenated into 'text' during preprocessing (a sketch of this follows below).
This completes loading the local JSON data. From here the dataset can be preprocessed as needed (field concatenation, tokenization, etc.) and then used for model training.
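A minimal sketch of that preprocessing, assuming the trainer is given a single 'text' column (the prompt layout and the to_text helper are illustrative choices, not a fixed TRL API):

```python
from datasets import load_dataset

dataset = load_dataset('json', data_files='apple_data.json')

def to_text(example):
    # Concatenate the three fields into one training string.
    # The exact prompt layout is a free choice; this is just one option.
    return {
        "text": f"Instruction: {example['instruction']}\n"
                f"Input: {example['input']}\n"
                f"Output: {example['output']}"
    }

# Keep only the new 'text' column so the trainer sees a single text field.
train_dataset = dataset['train'].map(to_text, remove_columns=dataset['train'].column_names)
print(train_dataset.column_names)  # ['text']
```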
An example that shows, step by step, how the dataset changes
import argparse
import pprint

from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_args):
    ################
    # Model and tokenizer initialization
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create the model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model_kwargs.pop("use_cache", None)
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # set the padding token

    # Set the chat template
    original_chat_template = tokenizer.chat_template
    if tokenizer.chat_template is None:
        print("Applying the default chat template...")
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    print("\n===== Chat template info =====")
    print(f"Chat template: {tokenizer.chat_template[:200]}...")  # show part of the template

    ################
    # Dataset processing and tracking
    ################
    # 1. Load the raw dataset
    print("\n" + "=" * 50)
    print("Stage 1: Load the raw dataset")
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_split = script_args.dataset_train_split
    print(f"Dataset splits: {list(dataset.keys())}")
    print(f"Number of training examples: {len(dataset[train_split])}")
    print("Raw sample example:")
    pprint.pprint(dataset[train_split][0])  # print the first raw sample
    print("Raw sample fields: ", dataset[train_split].column_names)
    print("=" * 50)

    # 2. Parse the conversation content
    print("\nStage 2: Parse the conversation content")
    # Extract turn information
    sample = dataset[train_split][0]
    print(f"Number of turns: {sample['num_turns']}")
    print("Role sequence: " + " → ".join([msg['role'] for msg in sample['messages']]))
    # Show the first two turns
    print("\nFirst 2 turns:")
    for i, msg in enumerate(sample['messages'][:2]):
        print(f"Turn {i + 1} ({msg['role']}): {msg['content'][:100]}...")
    print("=" * 50)

    # 3. Format conversations with the chat template
    print("\nStage 3: Format conversations with the chat template")
    # Take the first two samples for demonstration
    demo_samples = [dataset[train_split][i] for i in range(min(2, len(dataset[train_split])))]
    # Format the conversations
    formatted_texts = []
    for sample in demo_samples:
        # Use the tokenizer's chat template to format the multi-turn conversation
        formatted = tokenizer.apply_chat_template(
            sample['messages'],
            tokenize=False,
            add_generation_prompt=False,  # no generation prompt, since this is training data
        )
        formatted_texts.append(formatted)
        print(f"\nFormatted conversation sample (first 300 chars):\n{formatted[:300]}...")
    print("=" * 50)

    # 4. Tokenization
    print("\nStage 4: Tokenization")
    # Note: the parameter is max_length, not max_seq_length (the name SFTConfig actually uses)
    print(f"Parameters: max_length={training_args.max_length}, padding='max_length', truncation=True")

    def preprocess_function(examples):
        # Format every conversation with the chat template
        texts = [
            tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            for msgs in examples['messages']
        ]
        # Tokenize, again using max_length
        tokenized = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=training_args.max_length,
            return_overflowing_tokens=False,
        )
        # Build labels (padding positions are set to -100 and excluded from the loss)
        tokenized["labels"] = [
            [label if mask == 1 else -100 for label, mask in zip(input_ids, attention_mask)]
            for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    # Process a handful of samples for demonstration
    demo_dataset = dataset[train_split].select(range(min(2, len(dataset[train_split]))))
    processed_dataset = demo_dataset.map(preprocess_function, batched=True, remove_columns=demo_dataset.column_names)

    # Show the processed data format
    print("\nProcessed sample structure:")
    pprint.pprint(processed_dataset[0].keys())  # fields: input_ids, attention_mask, labels
    print("\nProcessed sample examples:")
    for i in range(len(processed_dataset)):
        print(f"\nSample {i + 1} details:")
        print(f"input_ids (first 20): {processed_dataset[i]['input_ids'][:20]}")
        print(f"attention_mask (first 20): {processed_dataset[i]['attention_mask'][:20]}")
        print(f"labels (first 20): {processed_dataset[i]['labels'][:20]}")
        print(f"Total sequence length: {len(processed_dataset[i]['input_ids'])}")  # should equal training_args.max_length
    print("=" * 50)

    # 5. Convert to model input format (tensors)
    print("\nStage 5: Convert to model input format")
    # Convert to PyTorch tensors
    processed_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print(f"Data type after conversion: {type(processed_dataset[0]['input_ids'])}")
    print(f"Tensor shape: {processed_dataset[0]['input_ids'].shape}")  # single-sample shape: (max_length,)
    print("=" * 50)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,  # use the processed dataset
        eval_dataset=(
            dataset[script_args.dataset_test_split].select(
                range(min(2, len(dataset[script_args.dataset_test_split])))
            )
            if training_args.eval_strategy != "no"
            else None
        ),
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    print("\nStarting training (demonstration only; uncomment to actually train)")
    # trainer.train()

    # Save the model
    # trainer.save_model(training_args.output_dir)
    # if training_args.push_to_hub:
    #     trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)
Command
python sft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-4 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--eval_strategy steps \
--eval_steps 100 \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--output_dir Qwen2-0.5B-SFT
Output
===== Chat template info =====
Chat template: {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
'...
==================================================
Stage 1: Load the raw dataset
Dataset splits: ['train', 'test']
Number of training examples: 15806
Raw sample example:
{'messages': [{'content': 'Recommend a movie to watch.\n', 'role': 'user'},
{'content': 'I would recommend the movie, "The Shawshank '
'Redemption" which is a classic drama film starring '
'Tim Robbins and Morgan Freeman. This film tells a '
'powerful story about hope and resilience, as it '
'follows the story of a young man who is wrongfully '
'convicted of murder and sent to prison. Amidst the '
'harsh realities of prison life, the protagonist '
'forms a bond with a fellow inmate, and together '
'they navigate the challenges of incarceration, '
'while holding on to the hope of eventual freedom. '
'This timeless movie is a must-watch for its moving '
'performances, uplifting message, and unforgettable '
'storytelling.',
'role': 'assistant'},
{'content': "Describe the character development of Tim Robbins' "
'character in "The Shawshank Redemption".',
'role': 'user'},
{'content': 'In "The Shawshank Redemption", Tim Robbins plays '
'the character of Andy Dufresne, a banker who is '
'wrongfully convicted of murdering his wife and her '
'lover. Over the course of the film, we see a '
"significant transformation in Andy's character.\n"
'\n'
'At the beginning of the movie, Andy is a quiet, '
'reserved man who maintains his innocence but is '
'largely passive in the face of his unjust '
'conviction. He seems almost aloof, detached from '
'the harsh reality of his situation. However, '
'beneath this seemingly impassive exterior, Andy is '
'a man of deep intelligence and inner strength.\n'
'\n'
'As the story progresses, Andy begins to adapt to '
'his new environment. He uses his banking skills to '
'gain favor with the prison guards and the warden, '
'providing them with financial advice and even '
'helping them with tax evasion. This not only earns '
'him certain privileges but also gives him a measure '
'of protection within the prison walls.\n'
'\n'
'Despite the grim circumstances, Andy never loses '
'hope. He befriends a fellow inmate, Red, and shares '
'with him his dream of living a quiet life in '
'Zihuatanejo, a small town on the Pacific coast of '
'Mexico. This dream keeps him going and becomes a '
'symbol of hope for other inmates as well.\n'
'\n'
"Perhaps the most significant development in Andy's "
'character comes towards the end of the film, when '
'it is revealed that he has been digging a tunnel '
'through his cell wall for the past 19 years. This '
'act of defiance, carried out with nothing but a '
"small rock hammer, is a testament to Andy's "
'resilience and determination. It also shows his '
'cunning, as he manages to keep his plan a secret '
'from everyone, including Red.\n'
'\n'
"In the end, Andy's character arc is one of hope, "
'resilience, and redemption. Despite being placed in '
'a situation that seems hopeless, he never gives up. '
'He maintains his dignity, uses his intelligence to '
'his advantage, and ultimately, manages to escape '
'from Shawshank, proving his innocence in the '
'process. His character serves as a powerful symbol '
"of the human spirit's ability to endure and "
'overcome even the most challenging circumstances.',
'role': 'assistant'},
{'content': 'Explain the significance of the friendship between '
"Andy and Red in shaping Andy's character "
'development.',
'role': 'user'},
{'content': 'The friendship between Andy Dufresne (Tim Robbins) '
'and Ellis "Red" Redding (Morgan Freeman) is central '
'to the narrative of "The Shawshank Redemption", and '
"plays a significant role in shaping Andy's "
'character development.\n'
'\n'
'When Andy first arrives at Shawshank, he is a '
'quiet, introverted man, largely keeping to himself. '
'Red, who is the long-term inmate, takes an interest '
"in him and they slowly develop a friendship. Red's "
'wisdom, street-smart attitude, and his ability to '
"procure items become instrumental in Andy's "
'survival and eventual escape.\n'
'\n'
"Red is initially skeptical of Andy's innocence and "
'his hopeful outlook on life. However, as their '
'friendship grows, Red becomes more receptive to '
"Andy's perspective. This friendship provides Andy "
'with a confidant, a sounding board, and a '
'supportive ally amidst the harsh realities of '
'prison life. \n'
'\n'
"Andy's influence on Red is equally profound. Andy's "
'unyielding hope and resilience slowly chip away at '
"Red's hardened cynicism. Andy shares his dreams of "
'freedom and his plans for the future with Red, '
'which initially seem unrealistic to Red, but over '
"time, Andy's unwavering belief in hope begins to "
"influence Red's outlook on life.\n"
'\n'
'In many ways, their friendship is a beacon of hope '
'and humanity in an otherwise oppressive '
"environment. It's through this friendship that Andy "
'finds the strength to maintain his dignity, '
'persevere, and ultimately, to engineer his daring '
"escape. It's also through this friendship that Red "
'finds hope for redemption and a life beyond the '
'prison walls.\n'
'\n'
'In conclusion, the friendship between Andy and Red '
"is a pivotal element in shaping Andy's character "
"development. It's through this bond that Andy finds "
'the strength to endure his unjust imprisonment and '
'to hold onto hope, ultimately leading to his '
'redemption.',
'role': 'assistant'}],
'num_turns': 6,
'source': 'GPT4LLM'}
Raw sample fields:  ['source', 'messages', 'num_turns']
==================================================
Stage 2: Parse the conversation content
Number of turns: 6
Role sequence: user → assistant → user → assistant → user → assistant
First 2 turns:
Turn 1 (user): Recommend a movie to watch.
...
Turn 2 (assistant): I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim R...
==================================================
Stage 3: Format conversations with the chat template
Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Recommend a movie to watch.
<|im_end|>
<|im_start|>assistant
I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim Robbins and Morgan Freeman. This film tells a powerful story about...
Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Determine the result obtained by evaluating 5338245-50629795848152. Numbers and symbols only, please.<|im_end|>
<|im_start|>assistant
5338245 - 50629795848152 = -50629790509907<|im_end|>
...
==================================================
Stage 4: Tokenization
Parameters: max_length=1024, padding='max_length', truncation=True
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 77.97 examples/s]
Processed sample structure:
dict_keys(['input_ids', 'attention_mask', 'labels'])
Processed sample examples:
Sample 1 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
Total sequence length: 1024
Sample 2 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
Total sequence length: 1024
==================================================
Stage 5: Convert to model input format
Data type after conversion: <class 'torch.Tensor'>
Tensor shape: torch.Size([1024])
==================================================
From human-readable interaction text to tensors the model can compute on, step by step
Stage 1: Raw dataset loading (preserve the original interaction structure)
- Input: the dataset files, loaded with load_dataset.
- Processing: read the raw data from the source without changing its format, keeping every field and all of the original interaction information.
- Output format: a Dataset object in which each sample is a dict with several fields. The core fields are:
  - messages: a list whose elements are single turns ({"role": "user"/"assistant", "content": "turn content"}).
  - num_turns: an integer, the total number of turns (e.g. 6).
  - source: a string identifying the data source (e.g. "GPT4LLM").
- Example (data):
  { "source": "GPT4LLM", "num_turns": 6, "messages": [ {"role": "user", "content": "Recommend a movie to watch.\n"}, {"role": "assistant", "content": "I would recommend the movie..."}, # more turns... ] }
Stage 2: Conversation parsing (extract the core interaction information)
- Input: a raw dataset sample (the output of Stage 1).
- Processing: extract the key interaction information from the raw data and understand the conversation structure (in preparation for formatting).
- Output:
  - Number of turns: taken from num_turns (e.g. 6).
  - Role sequence: the order of role values in messages (e.g. user → assistant → user → assistant ...).
  - Raw turn content: the content text of each turn (e.g. the user's movie-recommendation request and the assistant's reply).
- Purpose: make the interaction logic of the data explicit (alternating user question and assistant answer) and lay the groundwork for a unified format.
Stage 3: Chat-template formatting (unify the model input format)
- Input: the multi-turn messages (roles + contents) parsed in Stage 2.
- Processing: use the tokenizer's apply_chat_template method to turn the loose messages into a single string in a format the model recognizes.
  - Core logic: following the model's predefined template (e.g. Qwen's), add role markers (e.g. <|user|>, <|assistant|>) and separators (newlines, special tokens) to every turn so that the context stays coherent.
- Output format: a single string containing the full conversation history, with role markers and structural separators.
- Example (based on the data above):
  "<|user|>Recommend a movie to watch.\n<|assistant|>I would recommend the movie, "The Shawshank Redemption"...<|user|>Describe the character development of Tim Robbins' character...<|assistant|>In "The Shawshank Redemption", Tim Robbins plays..."
- Purpose: the unified format lets the model recognize who is speaking and where the role boundaries of the context lie (otherwise it cannot tell the user's words from the assistant's). A standalone sketch of this step follows below.
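A minimal standalone sketch of this step, assuming the Qwen2-0.5B tokenizer used in the script above (any tokenizer that defines a chat template behaves the same way):

```python
from transformers import AutoTokenizer

# Assumption: Qwen/Qwen2-0.5B ships a chat template; swap in your own model if needed.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

messages = [
    {"role": "user", "content": "Recommend a movie to watch."},
    {"role": "assistant", "content": "I would recommend The Shawshank Redemption."},
]

# tokenize=False returns the formatted string; no generation prompt, since this is training data.
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(formatted)
```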
Stage 4: Tokenization (text to number sequences)
- Input: the formatted conversation string produced in Stage 3.
- Processing: the tokenizer converts the text into number sequences the model can understand. The core steps are:
  - Tokenization: split the string into the smallest semantic units in the model's vocabulary (tokens), e.g. "Shawshank" → 3456.
  - Length normalization: truncate (too long) or pad (too short) each sequence to training_args.max_length (e.g. 512); padding uses pad_token (usually the eos_token).
  - Auxiliary fields:
    - input_ids: the sequence of token IDs (the model's core input).
    - attention_mask: a 0/1 sequence (1 for real tokens, 0 for padding).
    - labels: the sequence used for the loss (identical to input_ids, except that padding positions are set to -100 so the model does not learn the padding).
- Output format: a dict with three core fields (all lists):
  { "input_ids": [101, 2345, 678, ..., 0, 0], "attention_mask": [1, 1, 1, ..., 0, 0], "labels": [101, 2345, 678, ..., -100, -100] }
- Purpose: convert human-readable text into numbers the model can compute on, while attention_mask and labels tell the model what to attend to and what to learn from. A minimal sketch of the label masking appears after this list.
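The label-masking logic in isolation, using a hypothetical 4-token sequence padded to length 6 (plain Python, no libraries required):

```python
# Hypothetical tokenized sample: 4 real tokens padded to max_length=6.
input_ids      = [151644, 872, 198, 151645, 0, 0]
attention_mask = [1, 1, 1, 1, 0, 0]

# Copy input_ids into labels, but mark padding positions with -100
# so they are ignored by the cross-entropy loss.
labels = [tok if mask == 1 else -100 for tok, mask in zip(input_ids, attention_mask)]
print(labels)  # [151644, 872, 198, 151645, -100, -100]
```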
Stage 5: Conversion to model input format (tensorize for batching)
- Input: the input_ids, attention_mask, and labels lists produced in Stage 4.
- Processing: convert the lists into PyTorch tensors (torch.Tensor) with a uniform shape.
- Output format: a dict of tensors, each field shaped (max_length,) for a single sample or (batch_size, max_length) for a batch:
  { "input_ids": tensor([101, 2345, 678, ..., 0, 0]), "attention_mask": tensor([1, 1, 1, ..., 0, 0]), "labels": tensor([101, 2345, 678, ..., -100, -100]) }
- Purpose: satisfy the model's input requirements (models only accept tensors) and enable batched computation such as processing several samples in parallel; see the batching sketch below.
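A small sketch of how the tensorized dataset batches, assuming the processed_dataset from the script above (the DataLoader and batch size are illustrative; SFTTrainer builds its own data loader internally):

```python
from torch.utils.data import DataLoader

# processed_dataset already has set_format("torch", ...) applied,
# and every sample has the same padded length, so default collation stacks them.
loader = DataLoader(processed_dataset, batch_size=2)

batch = next(iter(loader))
print(batch["input_ids"].shape)       # torch.Size([2, max_length])
print(batch["attention_mask"].shape)  # torch.Size([2, max_length])
print(batch["labels"].shape)          # torch.Size([2, max_length])
```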
Stage | How the data changes | Core goal |
---|---|---|
Raw loading | Multi-field dict (original interaction preserved) | Keep the data exactly as it is |
Conversation parsing | Turns, roles, and contents extracted | Understand the structure, prepare for formatting |
Template formatting | Multi-turn messages → one string with role markers | Let the model recognize roles and context boundaries |
Tokenization | String → input_ids / attention_mask / labels | Turn text into number sequences the model understands |
Tensor conversion | Lists → PyTorch tensors | Match the model input format, support batching and backpropagation |
The load_dataset function in detail
load_dataset is the core function of the Hugging Face datasets library for loading datasets. It supports the Hugging Face Hub (remote repositories), local files, and custom formats, and handles many data formats (CSV, JSON, Parquet, images, audio, etc.). Its job is to simplify fetching, processing, and caching datasets, returning a Dataset or DatasetDict object that can be used directly for model training/evaluation.
Parameters
Parameter | Type | Default | Description |
---|---|---|---|
path | str | none (required) | Source path or name of the dataset, which determines how it is loaded: a Hub repo name (e.g. 'cornell-movie-review-data/rotten_tomatoes') loads from the Hub; a local directory (e.g. './data') loads from a local folder; a data format (e.g. 'csv') loads the files given in data_files. |
name | str (optional) | None | Dataset configuration name (some datasets have several sub-configurations, e.g. the 'sst2' task of 'nyu-mll/glue'). |
data_dir | str (optional) | None | Local data directory. If set while data_files is None, all files in that directory are loaded (equivalent to data_files=os.path.join(data_dir, '**')). |
data_files | str / Sequence[str] / Mapping[str, Union[str, Sequence[str]]] (optional) | None | Concrete data file paths: a single file (e.g. 'train.csv'); a list of files (e.g. ['train1.csv', 'train2.csv']); or a split mapping (e.g. {'train': 'train.csv', 'test': 'test.csv'}) assigning files to splits. |
split | str / Split (optional) | None | Which split to load (e.g. 'train', 'test', 'train+test'). If None, a DatasetDict with all splits is returned; if set, a single Dataset is returned. |
cache_dir | str (optional) | ~/.cache/huggingface/datasets | Cache directory for downloaded/processed datasets (avoids repeated downloads). |
features | Features (optional) | None | Custom feature schema (e.g. declare fields as text, integers, ...), overriding the automatically inferred features. |
download_config | DownloadConfig (optional) | None | Download configuration (timeouts, proxies, ...) controlling the details of data downloads. |
download_mode | DownloadMode / str (optional) | REUSE_DATASET_IF_EXISTS | Download mode: REUSE_DATASET_IF_EXISTS (default) reuses a local cache if present; FORCE_REDOWNLOAD forces a re-download; REUSE_CACHE_IF_EXISTS reuses the downloaded raw files but regenerates the prepared dataset. |
verification_mode | VerificationMode / str (optional) | BASIC_CHECKS | Verification mode: BASIC_CHECKS (default) runs basic checks (files exist, sizes match); ALL_CHECKS runs full checks (including hash matching); NO_CHECKS skips verification. If save_infos=True, the default is upgraded to ALL_CHECKS. |
keep_in_memory | bool (optional) | None | Whether to load the dataset into memory: None (default) decides automatically (small datasets in memory, large ones read from disk); True/False forces loading or not loading into memory. |
save_infos | bool (optional) | False | Whether to save dataset metadata (checksums, sizes, split info) to the cache directory. |
revision | str / Version (optional) | None | Version of a Hub dataset to load (branch name such as 'main', commit SHA, or tag). |
token | bool / str (optional) | None | Token for accessing private Hub datasets: True reads the locally stored token (from ~/.huggingface); a string is used directly as the token. |
streaming | bool (optional) | False | Whether to stream: False (default) downloads and caches the full dataset; True streams the data on the fly without downloading (suitable for very large datasets; returns an IterableDataset). |
num_proc | int (optional) | None | Number of processes for parallel download/preprocessing of local datasets; multiprocessing is disabled by default. |
storage_options | dict (optional) | None | Experimental; options passed to the file-system backend (e.g. credentials for cloud storage). |
**config_kwargs | extra keyword arguments | none | Additional configuration passed to the dataset builder (DatasetBuilder). |
Return value
- If streaming=False (default):
  - if split is None: returns a DatasetDict containing all splits (e.g. {'train': Dataset, 'test': Dataset});
  - if split is given: returns a single Dataset for that split.
- If streaming=True:
  - if split is None: returns an IterableDatasetDict;
  - if split is given: returns an IterableDataset (a streaming iterator that does not cache the full data).
A small example of how the return type changes with split is sketched below.
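A small sketch of the split behavior, reusing the local apple_data.json file from earlier (nothing here depends on a remote dataset):

```python
from datasets import load_dataset, Dataset, DatasetDict

# No split given: a DatasetDict holding every split (here just 'train').
ds_dict = load_dataset('json', data_files='apple_data.json')
assert isinstance(ds_dict, DatasetDict)

# Explicit split: a single Dataset.
ds_train = load_dataset('json', data_files='apple_data.json', split='train')
assert isinstance(ds_train, Dataset)

print(type(ds_dict).__name__, type(ds_train).__name__)  # DatasetDict Dataset
```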
Key characteristics
- Multiple sources: Hub repositories, local files, and custom formats (images, audio, etc.).
- Automation: data formats are parsed, feature types inferred, and processed results cached automatically (no repeated work).
- Flexible splits: load a single split (e.g. only the training set) or combined splits (e.g. 'train+validation').
- Streaming: suited to very large datasets (no local storage used, data read on the fly).
- Version control: the revision parameter pins a specific version of a Hub dataset for reproducibility.
Example scenarios
- Load a public dataset from the Hub: load_dataset('rotten_tomatoes', split='train')
- Load a local CSV file: load_dataset('csv', data_files='./train.csv')
- Stream a very large dataset: load_dataset('large_dataset', streaming=True) (see the iteration sketch below)
- Load a specific dataset version: load_dataset('my_dataset', revision='v1.0', split='test')
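For the streaming case, the returned IterableDataset is iterated rather than indexed. A hedged sketch, using the trl-lib/Capybara dataset from earlier purely as an illustration:

```python
from datasets import load_dataset

# streaming=True returns an IterableDataset: nothing is downloaded up front,
# and there is no random access, only iteration.
stream = load_dataset("trl-lib/Capybara", split="train", streaming=True)

# Take the first two examples by iterating.
for i, example in enumerate(stream):
    print(example.keys())
    if i >= 1:
        break
```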