TRL - Transformer Reinforcement Learning: Datasets Passed to SFTTrainer

flyfish

Let's use the JSON format as an example.

Step 1: A local JSON file

The contents of apple_data.json are as follows:

[
  {
    "instruction": "介绍苹果的基本分类(按成熟季节)",
    "input": "苹果的季节分类",
    "output": "苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。"
  },
  {
    "instruction": "用拟人化手法描述一个苹果",
    "input": "一个红苹果",
    "output": "我是一个圆滚滚的红苹果,穿着半红半黄的“外套”,表面光滑得像打了蜡。头顶的果柄是我的“小辫子”,轻轻一摸,还能感受到皮肤下饱满的汁水在悄悄晃动。凑近闻,一股清甜的香气会钻进你的鼻子——那是我在邀请你咬一口呀!"
  },
  {
    "instruction": "解答关于苹果储存的问题",
    "input": "如何延长苹果的保鲜期?",
    "output": "延长苹果保鲜期的方法有:①冷藏储存:放入冰箱冷藏室(温度0-4℃),用保鲜袋包裹(留小孔透气),可保存1-2个月;②避免混放:与香蕉、梨等释放乙烯的水果分开,防止加速成熟;③阴凉通风处:常温下置于阴凉干燥处,可保存1-2周,适合短期食用。"
  },
  {
    "instruction": "列举苹果的3种常见食用方法(非直接生食)",
    "input": "苹果的加工吃法",
    "output": "苹果的非生食做法有:①苹果派:将苹果切丁,加黄油、糖炒软,与面粉、鸡蛋等制作的派皮烘焙,外酥内软;②苹果银耳汤:苹果切块与银耳、红枣同煮,加冰糖调味,润肺养颜;③烤苹果:苹果去核,填入蜂蜜和坚果,烤箱烤制,果香浓郁,适合秋冬食用。"
  },
  {
    "instruction": "对比苹果与梨的核心区别(从口感和营养侧重)",
    "input": "苹果和梨的区别",
    "output": "从口感看:苹果果肉多脆嫩(除面苹果外),果皮较薄;梨果肉偏细腻多汁,部分品种(如酥梨)果皮更光滑。从营养侧重看:苹果膳食纤维(尤其是果胶)含量更高,有助于肠道蠕动;梨的水分和梨醇含量更丰富,润肺生津效果更突出,适合干燥季节食用。"
  }
]

Step 2: Load the local JSON file with load_dataset

When calling load_dataset, set the data format to 'json' (since the source is a JSON file) and point the data_files parameter at the local file path.

from datasets import load_dataset

# Load the local JSON file
dataset = load_dataset(
    path='json',  # the data format is JSON
    data_files='apple_data.json'  # local JSON file path (use a path such as './data/apple_data.json' if it is not in the current directory)
)

# Inspect the result
print(dataset)  # print the dataset structure
print("\nFirst example:")
print(dataset['train'][0])  # print the first example

Output

Running this returns a DatasetDict whose only split is named 'train': the JSON file defines no explicit splits, so load_dataset places all of the data into a default 'train' split.

Generating train split: 5 examples [00:00, 766.53 examples/s]
DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 5
    })
})

First example:
{'instruction': '介绍苹果的基本分类(按成熟季节)', 'input': '苹果的季节分类', 'output': '苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。'}

Notes

  1. Data format: the data is a JSON array ([{}, {}, ...]), the standard format load_dataset expects; each object is one sample.
  2. Split handling: if the data has several splits (e.g. train/test), save each split as its own JSON file (e.g. train.json and test.json) and pass data_files={'train': 'train.json', 'test': 'test.json'}.
  3. Downstream use: the loaded dataset['train'] can be passed straight to SFTTrainer as training data, provided the fields match. SFTTrainer may expect a 'text' field, in which case you preprocess by concatenating instruction + input + output into a 'text' field (see the sketch below).

That completes loading the local JSON data. From here you can preprocess as needed (field concatenation, tokenization, etc.) and then use it for model training.
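Here is a minimal sketch of the field concatenation mentioned in note 3, assuming a simple instruction/input/output prompt layout (the layout itself is an illustrative choice, not something TRL mandates):

from datasets import load_dataset

dataset = load_dataset('json', data_files='apple_data.json')

def to_text(example):
    # Hypothetical prompt layout; adapt it to the format your model expects.
    return {
        "text": f"Instruction: {example['instruction']}\n"
                f"Input: {example['input']}\n"
                f"Output: {example['output']}"
    }

# map() adds the 'text' field that SFTTrainer can consume directly.
dataset = dataset.map(to_text)
print(dataset['train'][0]['text'][:80])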

An example that traces how the dataset changes at each step

import argparse
import pprint

from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_args):
    ################
    # Model and tokenizer initialization
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create the model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText
        model_kwargs.pop("use_cache", None)  # use_cache is not passed to image-text-to-text models
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # set the padding token

    # Set up the chat template
    original_chat_template = tokenizer.chat_template
    if tokenizer.chat_template is None:
        print("Applying default chat template...")
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    print("\n===== Chat template info =====")
    print(f"Chat template: {tokenizer.chat_template[:200]}...")  # show part of the template

    ################
    # Dataset processing and tracing
    ################
    # 1. Load the raw dataset
    print("\n" + "="*50)
    print("Stage 1: Load the raw dataset")
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_split = script_args.dataset_train_split
    print(f"Dataset splits: {list(dataset.keys())}")
    print(f"Number of training examples: {len(dataset[train_split])}")
    print("Raw sample example:")
    pprint.pprint(dataset[train_split][0])  # print the first raw sample
    print("Raw sample fields: ", dataset[train_split].column_names)
    print("="*50)

    # 2. Parse the conversation
    print("\nStage 2: Parse the conversation")
    # Extract turn-level information
    sample = dataset[train_split][0]
    print(f"Number of turns: {sample['num_turns']}")
    print("Role sequence: " + " → ".join([msg['role'] for msg in sample['messages']]))

    # Show the first 2 turns
    print("\nFirst 2 turns:")
    for i, msg in enumerate(sample['messages'][:2]):
        print(f"Turn {i+1} ({msg['role']}): {msg['content'][:100]}...")
    print("="*50)

    # 3. Apply the chat template
    print("\nStage 3: Apply the chat template to format conversations")
    # Take the first 2 samples for demonstration
    demo_samples = [dataset[train_split][i] for i in range(min(2, len(dataset[train_split])))]

    # Format the conversations
    formatted_texts = []
    for sample in demo_samples:
        # Use the tokenizer's chat template to format the multi-turn conversation
        formatted = tokenizer.apply_chat_template(
            sample['messages'],
            tokenize=False,
            add_generation_prompt=False  # no generation prompt: this is training data
        )
        formatted_texts.append(formatted)
        print(f"\nFormatted conversation sample (first 300 chars):\n{formatted[:300]}...")
    print("="*50)

    # 4. Tokenization
    print("\nStage 4: Tokenization")
    # Note: the parameter is max_length (the correct SFTConfig name), not max_seq_length
    print(f"Parameters: max_length={training_args.max_length}, padding='max_length', truncation=True")

    def preprocess_function(examples):
        # Apply the chat template to every conversation
        texts = [
            tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            for msgs in examples['messages']
        ]

        # Tokenize, again using max_length
        tokenized = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=training_args.max_length,
            return_overflowing_tokens=False,
        )

        # Build labels (padded positions become -100 and are excluded from the loss)
        tokenized["labels"] = [
            [label if mask == 1 else -100 for label, mask in zip(input_ids, attention_mask)]
            for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    # Process a couple of samples for demonstration
    demo_dataset = dataset[train_split].select(range(min(2, len(dataset[train_split]))))
    processed_dataset = demo_dataset.map(preprocess_function, batched=True, remove_columns=demo_dataset.column_names)

    # Show the processed data format
    print("\nProcessed sample structure:")
    pprint.pprint(processed_dataset[0].keys())  # fields: input_ids, attention_mask, labels
    print("\nProcessed sample examples:")
    for i in range(len(processed_dataset)):
        print(f"\nSample {i+1} details:")
        print(f"input_ids (first 20): {processed_dataset[i]['input_ids'][:20]}")
        print(f"attention_mask (first 20): {processed_dataset[i]['attention_mask'][:20]}")
        print(f"labels (first 20): {processed_dataset[i]['labels'][:20]}")
        print(f"Total sequence length: {len(processed_dataset[i]['input_ids'])}")  # should equal training_args.max_length
    print("="*50)

    # 5. Convert to model input format (tensors)
    print("\nStage 5: Convert to model input format")
    # Convert to PyTorch tensors
    processed_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print(f"Data type after conversion: {type(processed_dataset[0]['input_ids'])}")
    print(f"Tensor shape: {processed_dataset[0]['input_ids'].shape}")  # single-sample shape: (max_length,)
    print("="*50)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,  # use the processed dataset
        eval_dataset=(
            dataset[script_args.dataset_test_split].select(
                range(min(2, len(dataset[script_args.dataset_test_split])))
            )
            if training_args.eval_strategy != "no"
            else None
        ),
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    print("\nStarting training (demo only; uncomment trainer.train() to run it)")
    # trainer.train()

    # Save the model
    # trainer.save_model(training_args.output_dir)
    # if training_args.push_to_hub:
    #     trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)

Command

python sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir Qwen2-0.5B-SFT 

Output

===== Chat template info =====


Chat template: {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
'...

==================================================

Stage 1: Load the raw dataset

Dataset splits: ['train', 'test']
Number of training examples: 15806
Raw sample example:


{'messages': [{'content': 'Recommend a movie to watch.\n', 'role': 'user'},
              {'content': 'I would recommend the movie, "The Shawshank '
                          'Redemption" which is a classic drama film starring '
                          'Tim Robbins and Morgan Freeman. This film tells a '
                          'powerful story about hope and resilience, as it '
                          'follows the story of a young man who is wrongfully '
                          'convicted of murder and sent to prison. Amidst the '
                          'harsh realities of prison life, the protagonist '
                          'forms a bond with a fellow inmate, and together '
                          'they navigate the challenges of incarceration, '
                          'while holding on to the hope of eventual freedom. '
                          'This timeless movie is a must-watch for its moving '
                          'performances, uplifting message, and unforgettable '
                          'storytelling.',
               'role': 'assistant'},
              {'content': "Describe the character development of Tim Robbins' "
                          'character in "The Shawshank Redemption".',
               'role': 'user'},
              {'content': 'In "The Shawshank Redemption", Tim Robbins plays '
                          'the character of Andy Dufresne, a banker who is '
                          'wrongfully convicted of murdering his wife and her '
                          'lover. Over the course of the film, we see a '
                          "significant transformation in Andy's character.\n"
                          '\n'
                          'At the beginning of the movie, Andy is a quiet, '
                          'reserved man who maintains his innocence but is '
                          'largely passive in the face of his unjust '
                          'conviction. He seems almost aloof, detached from '
                          'the harsh reality of his situation. However, '
                          'beneath this seemingly impassive exterior, Andy is '
                          'a man of deep intelligence and inner strength.\n'
                          '\n'
                          'As the story progresses, Andy begins to adapt to '
                          'his new environment. He uses his banking skills to '
                          'gain favor with the prison guards and the warden, '
                          'providing them with financial advice and even '
                          'helping them with tax evasion. This not only earns '
                          'him certain privileges but also gives him a measure '
                          'of protection within the prison walls.\n'
                          '\n'
                          'Despite the grim circumstances, Andy never loses '
                          'hope. He befriends a fellow inmate, Red, and shares '
                          'with him his dream of living a quiet life in '
                          'Zihuatanejo, a small town on the Pacific coast of '
                          'Mexico. This dream keeps him going and becomes a '
                          'symbol of hope for other inmates as well.\n'
                          '\n'
                          "Perhaps the most significant development in Andy's "
                          'character comes towards the end of the film, when '
                          'it is revealed that he has been digging a tunnel '
                          'through his cell wall for the past 19 years. This '
                          'act of defiance, carried out with nothing but a '
                          "small rock hammer, is a testament to Andy's "
                          'resilience and determination. It also shows his '
                          'cunning, as he manages to keep his plan a secret '
                          'from everyone, including Red.\n'
                          '\n'
                          "In the end, Andy's character arc is one of hope, "
                          'resilience, and redemption. Despite being placed in '
                          'a situation that seems hopeless, he never gives up. '
                          'He maintains his dignity, uses his intelligence to '
                          'his advantage, and ultimately, manages to escape '
                          'from Shawshank, proving his innocence in the '
                          'process. His character serves as a powerful symbol '
                          "of the human spirit's ability to endure and "
                          'overcome even the most challenging circumstances.',
               'role': 'assistant'},
              {'content': 'Explain the significance of the friendship between '
                          "Andy and Red in shaping Andy's character "
                          'development.',
               'role': 'user'},
              {'content': 'The friendship between Andy Dufresne (Tim Robbins) '
                          'and Ellis "Red" Redding (Morgan Freeman) is central '
                          'to the narrative of "The Shawshank Redemption", and '
                          "plays a significant role in shaping Andy's "
                          'character development.\n'
                          '\n'
                          'When Andy first arrives at Shawshank, he is a '
                          'quiet, introverted man, largely keeping to himself. '
                          'Red, who is the long-term inmate, takes an interest '
                          "in him and they slowly develop a friendship. Red's "
                          'wisdom, street-smart attitude, and his ability to '
                          "procure items become instrumental in Andy's "
                          'survival and eventual escape.\n'
                          '\n'
                          "Red is initially skeptical of Andy's innocence and "
                          'his hopeful outlook on life. However, as their '
                          'friendship grows, Red becomes more receptive to '
                          "Andy's perspective. This friendship provides Andy "
                          'with a confidant, a sounding board, and a '
                          'supportive ally amidst the harsh realities of '
                          'prison life. \n'
                          '\n'
                          "Andy's influence on Red is equally profound. Andy's "
                          'unyielding hope and resilience slowly chip away at '
                          "Red's hardened cynicism. Andy shares his dreams of "
                          'freedom and his plans for the future with Red, '
                          'which initially seem unrealistic to Red, but over '
                          "time, Andy's unwavering belief in hope begins to "
                          "influence Red's outlook on life.\n"
                          '\n'
                          'In many ways, their friendship is a beacon of hope '
                          'and humanity in an otherwise oppressive '
                          "environment. It's through this friendship that Andy "
                          'finds the strength to maintain his dignity, '
                          'persevere, and ultimately, to engineer his daring '
                          "escape. It's also through this friendship that Red "
                          'finds hope for redemption and a life beyond the '
                          'prison walls.\n'
                          '\n'
                          'In conclusion, the friendship between Andy and Red '
                          "is a pivotal element in shaping Andy's character "
                          "development. It's through this bond that Andy finds "
                          'the strength to endure his unjust imprisonment and '
                          'to hold onto hope, ultimately leading to his '
                          'redemption.',
               'role': 'assistant'}],
 'num_turns': 6,
 'source': 'GPT4LLM'}


Raw sample fields:  ['source', 'messages', 'num_turns']
==================================================

Stage 2: Parse the conversation

Number of turns: 6
Role sequence: user → assistant → user → assistant → user → assistant

First 2 turns:
Turn 1 (user): Recommend a movie to watch.
...
Turn 2 (assistant): I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim R...
==================================================

Stage 3: Apply the chat template to format conversations

Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Recommend a movie to watch.
<|im_end|>
<|im_start|>assistant
I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim Robbins and Morgan Freeman. This film tells a powerful story about...

Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Determine the result obtained by evaluating 5338245-50629795848152. Numbers and symbols only, please.<|im_end|>
<|im_start|>assistant
5338245 - 50629795848152 = -50629790509907<|im_end|>
...
==================================================

Stage 4: Tokenization

Parameters: max_length=1024, padding='max_length', truncation=True
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 77.97 examples/s]

Processed sample structure:
dict_keys(['input_ids', 'attention_mask', 'labels'])

Processed sample examples:

Sample 1 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
Total sequence length: 1024

Sample 2 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
Total sequence length: 1024
==================================================

Stage 5: Convert to model input format
Data type after conversion: <class 'torch.Tensor'>
Tensor shape: torch.Size([1024])
==================================================

From human-readable conversation text to model-computable tensors, step by step

Stage 1: Raw dataset loading (preserve the original interaction structure)

  • Input: the dataset file, loaded via load_dataset.
  • Processing: read the raw data from the source without any format changes, keeping every field and the original interaction information.
  • Output format: a Dataset object; each sample is a multi-field dict whose key fields are:
    • messages: a list in which each element is one turn ({"role": "user"/"assistant", "content": "..."}).
    • num_turns: an integer, the total number of turns (e.g. 6).
    • source: a string identifying the data source (e.g. "GPT4LLM").
  • Example (data):
    {
        "source": "GPT4LLM",
        "num_turns": 6,
        "messages": [
            {"role": "user", "content": "Recommend a movie to watch.\n"},
            {"role": "assistant", "content": "I would recommend the movie..."},
            # more turns...
        ]
    }

Stage 2: Conversation parsing (extract the core interaction information)

  • Input: a raw dataset sample (the output of stage 1).
  • Processing: extract the key interaction information from the raw data and understand the conversation structure (in preparation for formatting).
  • Output:
    • Turn count: read from num_turns (e.g. 6 turns).
    • Role sequence: the order of the role values in messages (e.g. user → assistant → user → assistant...).
    • Raw content: the content text of each turn (the user's movie-recommendation request, the assistant's reply, and so on).
  • Purpose: make the interaction logic explicit (alternating user question → assistant answer), laying the groundwork for a unified format (a standalone recap is sketched below).
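A standalone recap of this parsing step, assuming the Capybara-style dataset loaded in stage 1 is available as dataset:

sample = dataset['train'][0]
print(f"Number of turns: {sample['num_turns']}")
# Walk the messages list to reconstruct the role sequence.
print("Role sequence: " + " → ".join(msg['role'] for msg in sample['messages']))
print("First turn:", sample['messages'][0]['content'][:100])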

Stage 3: Chat-template formatting (unify the model input format)

  • Input: the multi-turn messages (roles + content) parsed in stage 2.
  • Processing: the tokenizer's apply_chat_template method turns the loose messages list into a single string in the unified format the model recognizes:
    • Core logic: following the model's predefined template (e.g. Qwen's), add role markers (such as <|user|> and <|assistant|>) and separators (newlines, special tokens) to each turn, keeping the context coherent.
  • Output format: one string containing the full conversation history, with role markers and structural separators.
  • Example (based on the data above; role markers shown schematically):
    "<|user|>Recommend a movie to watch.\n<|assistant|>I would recommend the movie, "The Shawshank Redemption"...<|user|>Describe the character development of Tim Robbins' character...<|assistant|>In "The Shawshank Redemption", Tim Robbins plays..."

  • Purpose: the unified format lets the model tell who is speaking and where the role boundaries are (otherwise it cannot distinguish the user's words from the assistant's). A minimal call is sketched below.
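A minimal sketch of the formatting call, assuming tokenizer is the Qwen-style tokenizer loaded earlier in the script:

messages = [
    {"role": "user", "content": "Recommend a movie to watch."},
    {"role": "assistant", "content": "I would recommend \"The Shawshank Redemption\"..."},
]
# tokenize=False returns the formatted string rather than token IDs;
# add_generation_prompt=False because this is training data, not an inference prompt.
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(formatted)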

Stage 4: Tokenization (text → numeric sequence)

  • Input: the formatted conversation string from stage 3.
  • Processing: the tokenizer converts the text into a numeric sequence the model can understand. The core steps:
    1. Tokenize: split the string into the smallest semantic units in the model's vocabulary (tokens), e.g. "Shawshank" → 3456 (illustrative ID).
    2. Length normalization: truncate (if too long) or pad (if too short) to training_args.max_length (1024 here); padding uses pad_token (usually eos_token).
    3. Generate auxiliary fields:
      • input_ids: the token-ID sequence (the model's core input).
      • attention_mask: a 0/1 sequence (1 marks valid tokens, 0 marks padding).
      • labels: the loss-target sequence (same as input_ids, but padded positions are set to -100 so the model does not learn from padding; see the sketch after this list).
  • Output format: a dict with three core fields (all lists):
    {
        "input_ids": [101, 2345, 678, ..., 0, 0],  # length = max_length; 0 is padding
        "attention_mask": [1, 1, 1, ..., 0, 0],    # marks valid tokens
        "labels": [101, 2345, 678, ..., -100, -100]  # padded positions excluded from the loss
    }

  • Purpose: turn human-readable text into numbers the model can compute on, while attention_mask and labels tell the model what to attend to and what to learn from.
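The label-masking rule from step 3 can be checked in isolation; a tiny sketch with made-up token IDs (0 standing in for the pad token):

input_ids      = [151644, 8948, 198, 151645, 0, 0]   # hypothetical IDs; trailing 0s are padding
attention_mask = [1, 1, 1, 1, 0, 0]

# Copy input_ids, but mark padded positions with -100 so the loss ignores them.
labels = [tok if mask == 1 else -100 for tok, mask in zip(input_ids, attention_mask)]
print(labels)  # [151644, 8948, 198, 151645, -100, -100]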

Stage 5: Convert to model input format (tensors for batching)

  • Input: the input_ids, attention_mask, and labels lists from stage 4.
  • Processing: convert the lists into PyTorch tensors (torch.Tensor) with uniform shapes.
  • Output format: a dict of tensors; each field has shape (max_length,) for a single sample or (batch_size, max_length) for a batch.
    {
        "input_ids": tensor([101, 2345, 678, ..., 0, 0]),  # shape: (max_length,)
        "attention_mask": tensor([1, 1, 1, ..., 0, 0]),
        "labels": tensor([101, 2345, 678, ..., -100, -100])
    }

  • Purpose: match the model's input requirements (models only accept tensors) and enable batched computation (processing multiple samples in parallel), as the DataLoader sketch below shows.
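Once set_format("torch") has been applied, the dataset plugs straight into a standard PyTorch DataLoader, which stacks equal-length samples into (batch_size, max_length) tensors. A minimal sketch, assuming processed_dataset from the script above:

from torch.utils.data import DataLoader

loader = DataLoader(processed_dataset, batch_size=2)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([2, max_length])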

In summary:

Stage | Data form change | Core goal
Raw loading | multi-field dict (original interaction kept) | preserve the data exactly as it is
Conversation parsing | extract turns, roles, and content | understand the structure and prepare for formatting
Template formatting | multi-turn messages → one role-tagged string | let the model recognize roles and context boundaries
Tokenization | string → input_ids / attention_mask / labels | turn the text into numbers the model can process
Tensor conversion | lists → PyTorch tensors | match the model input format; support batching and backpropagation

The load_dataset function in detail

load_dataset is the core function of the Hugging Face datasets library for loading datasets. It can load from the Hugging Face Hub (cloud repositories), local files, and custom formats, and it understands many data formats (CSV, JSON, Parquet, images, audio, and more). Its job is to simplify fetching, processing, and caching datasets, returning a Dataset or DatasetDict object that can be used directly for training or evaluation.

Parameters

  path (str, required)
      The dataset source path or name, which determines how loading works:
      - a Hub repo name (e.g. 'cornell-movie-review-data/rotten_tomatoes'): load from the Hub;
      - a local directory (e.g. './data'): load from a local folder;
      - a data format (e.g. 'csv'): load the files given in data_files as that format.

  name (str, optional, default None)
      Dataset configuration name (some datasets have several sub-configs, e.g. the 'sst2' subtask of 'nyu-mll/glue').

  data_dir (str, optional, default None)
      Local data directory. If set while data_files is None, every file under the directory is loaded (equivalent to data_files=os.path.join(data_dir, **)).

  data_files (str / Sequence[str] / Mapping[str, Union[str, Sequence[str]]], optional, default None)
      Concrete data file paths:
      - a single file (e.g. 'train.csv');
      - a list of files (e.g. ['train1.csv', 'train2.csv']);
      - a split mapping (e.g. {'train': 'train.csv', 'test': 'test.csv'}) assigning files to splits (train/test).

  split (str / Split, optional, default None)
      Which split to load (e.g. 'train', 'test', 'train+test').
      - None: return a DatasetDict containing all splits;
      - a specific value: return a single Dataset.

  cache_dir (str, optional, default ~/.cache/huggingface/datasets)
      Cache directory for downloaded and processed datasets (avoids repeated downloads).

  features (Features, optional, default None)
      Custom feature schema (e.g. declaring fields as text, integers, etc.), overriding the automatically inferred one.

  download_config (DownloadConfig, optional, default None)
      Download configuration (timeouts, proxies, etc.) controlling the details of data download.

  download_mode (DownloadMode / str, optional, default REUSE_DATASET_IF_EXISTS)
      Download mode:
      - REUSE_DATASET_IF_EXISTS (default): reuse the local cache if present;
      - FORCE_REDOWNLOAD: always re-download;
      - REUSE_CACHE_IF_EXISTS: reuse the cache, re-downloading if it is corrupted.

  verification_mode (VerificationMode / str, optional, default BASIC_CHECKS)
      Verification mode:
      - BASIC_CHECKS (default): basic checks (files exist, sizes match);
      - ALL_CHECKS: full checks (including hash matching);
      - NO_CHECKS: no verification.
      With save_infos=True the default is upgraded to ALL_CHECKS.

  keep_in_memory (bool, optional, default None)
      Whether to load the dataset into memory:
      - None (default): decide automatically (small datasets go to memory, large ones are read from disk);
      - True/False: force loading into memory or not.

  save_infos (bool, optional, default False)
      Whether to save dataset metadata (checksums, sizes, split info) to the cache directory.

  revision (str / Version, optional, default None)
      Version of a Hub dataset (a branch name such as 'main', a commit SHA, or a tag), for loading a specific version.

  token (bool / str, optional, default None)
      Token for accessing private Hub datasets:
      - True: read the token from ~/.huggingface;
      - a string: use that token directly.

  streaming (bool, optional, default False)
      Whether to stream:
      - False (default): download and cache the complete dataset;
      - True: no download; read on the fly (suited to very large datasets; returns an IterableDataset).

  num_proc (int, optional, default None)
      Number of processes for parallel download and preprocessing of local datasets; multiprocessing is disabled by default.

  storage_options (dict, optional, default None)
      Experimental; configuration passed to the filesystem backend (e.g. cloud-storage access keys).

  trust_remote_code (bool, optional, default False)
      Whether to trust dataset scripts on the Hub (for datasets that ship custom loading code):
      - True: run the remote script (only for repositories you trust);
      - False (default): refuse to run remote scripts.

  **config_kwargs (additional keyword arguments)
      Extra configuration forwarded to the dataset builder (DatasetBuilder).

Return value

  • streaming=False(默认):
    - 若 splitNone:返回 DatasetDict(包含所有数据拆分,如 {'train': Dataset, 'test': Dataset});
    - 若 split 指定:返回单个 Dataset(对应拆分的数据)。
  • streaming=True
    - 若 splitNone:返回 IterableDatasetDict
    - 若 split 指定:返回 IterableDataset(流式迭代器,不缓存完整数据)。
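A quick illustration of the four return types, reusing the local apple_data.json file from the first example:

from datasets import load_dataset

dd = load_dataset('json', data_files='apple_data.json')                    # DatasetDict
ds = load_dataset('json', data_files='apple_data.json', split='train')     # Dataset
idd = load_dataset('json', data_files='apple_data.json', streaming=True)   # IterableDatasetDict
ids = load_dataset('json', data_files='apple_data.json', streaming=True,
                   split='train')                                          # IterableDataset
print(type(dd).__name__, type(ds).__name__, type(idd).__name__, type(ids).__name__)
# DatasetDict Dataset IterableDatasetDict IterableDataset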

Key characteristics

  1. Multiple sources: works with Hub repositories, local files, and custom formats (images, audio, etc.).
  2. Automation: parses the data format, infers feature types, and caches processed results (avoiding repeated computation).
  3. Flexible splits: load a specific split (e.g. only the training set) or combined splits (e.g. 'train+validation').
  4. Streaming: suited to very large datasets (no local storage used; data is read on the fly).
  5. Version control: the revision parameter pins a specific Hub dataset version for reproducibility.

Example scenarios

  • Load a public dataset from the Hub: load_dataset('rotten_tomatoes', split='train')
  • Load a local CSV file: load_dataset('csv', data_files='./train.csv')
  • Stream a very large dataset: load_dataset('large_dataset', streaming=True)
  • Load a specific version of a dataset: load_dataset('my_dataset', revision='v1.0', split='test')