在自然语言处理(NLP)和机器学习领域,Hugging Face 的 datasets
库凭借其高效的数据加载和预处理能力成为开发者必备工具。本文通过代码示例详解 load_dataset
的核心用法,涵盖数据格式解析、公开数据集调用和自定义数据集构建。
一、数据格式解析与加载示例
datasets
库支持多种数据格式,包括 JSON、CSV、Parquet、Arrow 等,通过指定格式参数直接加载:
-
JSON 文件加载
处理嵌套字段时需指定field
参数:from datasets import load_dataset # 加载嵌套结构的 JSON 数据 dataset = load_dataset("json", data_files="data.json", field="data.items")
-
CSV 文件分拆训练/测试集
通过字典映射文件路径:data_files = {"train": "train.csv", "test": "test.csv"} dataset = load_dataset("csv", data_files=data_files)
-
Parquet 高效列式存储
适合大规模数据集:dataset = load_dataset("parquet", data_files="data.parquet")
-
内存数据快速构建
从 Python 字典或列表直接创建:data = {"text": ["sample1", "sample2"], "label": [0, 1]} dataset = Dataset.from_dict(data)
二、公开数据集调用实战
Hugging Face Hub 提供超过 40,000 个公开数据集,覆盖 NLP、CV、音频等领域:
-
NLP 任务:GLUE-MRPC
加载微软研究院的句子对分类数据集:dataset = load_dataset("glue", "mrpc", split="train") # 查看样例:句子对 + 二分类标签 print(dataset[0]) # {'sentence1': ..., 'sentence2': ..., 'label': 1}
-
问答任务:SQuAD
加载斯坦福问答数据集:dataset = load_dataset("squad", split="train") # 数据结构:context + question + answer
-
多模态:COCO 图像描述
加载图像-文本匹配数据:dataset = load_dataset("coco_captions", split="train") # 包含 PIL Image 对象和文本描述
-
自定义子集筛选
对大型数据集按条件过滤:imdb = load_dataset("imdb") positive_reviews = imdb.filter(lambda x: x["label"] == 1)
三、自定义数据集构建指南
当需要处理专有数据时,可通过以下方式构建数据集:
-
本地文件加载
结构化数据直接映射:# 加载多 JSONL 文件 dataset = load_dataset("json", data_files="*.jsonl", split="train")
-
动态内存构建
从数据库或 API 实时生成:import pandas as pd df = pd.read_csv("custom_data.csv") dataset = Dataset.from_pandas(df)
-
复杂预处理 Pipeline
结合map
函数实现数据增强:def preprocess(example): example["text"] = tokenizer(example["text"], truncation=True) return example dataset = dataset.map(preprocess, batched=True, num_proc=8)
-
分布式优化技巧
处理超大规模数据时:# 保存为 Arrow 格式加速后续加载 dataset.save_to_disk("processed_data") dataset = load_from_disk("processed_data")
四、性能优化与高级功能
-
流式加载(Streaming Mode)
处理超大数据避免内存溢出:dataset = load_dataset("bigscience/P3", streaming=True) for example in dataset.take(1000): process(example)
-
数据集版本控制
确保实验可复现性:dataset = load_dataset("my_dataset", revision="v1.2")
-
跨格式转换
与 PyTorch/TensorFlow 生态交互:pt_tensor = dataset.with_format("torch")["feature"]