Hugging Face datasets.Dataset
类详解
datasets.Dataset
是 Hugging Face datasets
库中的 核心数据结构,用于表示 结构化的 NLP 数据集,支持高效的数据处理、转换、筛选、格式化、保存与加载等操作。每个 Dataset
类似于一个增强版的 pandas.DataFrame
,但设计为支持大规模、高性能、可流式处理的训练数据集。
1. 创建和加载 Dataset
对象
1.1 从 Hugging Face Hub 加载现成数据集
from datasets import load_dataset
dataset = load_dataset("imdb") # 返回的是 DatasetDict
train_dataset = dataset["train"] # 获取具体的 Dataset 对象
1.2 从本地 CSV / JSON 文件加载
dataset = load_dataset("csv", data_files="data.csv")
1.3 自定义数据加载
from datasets import Dataset
data = {
"text": ["I love AI.", "Transformers are powerful."],
"label": [1, 1]
}
dataset = Dataset.from_dict(data)
2. Dataset
的常见属性
属性 | 作用 | 示例 |
---|---|---|
dataset.column_names | 列名 | ['text', 'label'] |
dataset.features | 特征结构 | {'text': Value('string'), 'label': Value('int64')} |
dataset.num_rows | 样本数 | dataset.num_rows |
dataset.shape | (num_rows, num_columns) | (2, 2) |
dataset[0] | 访问单条数据(字典) | {'text': 'I love AI.', 'label': 1} |
3. Dataset
的常用方法
方法 | 作用 |
---|---|
map(function) | 应用于每一行,进行批量处理或特征工程 |
filter(function) | 按条件过滤样本 |
select(indices) | 按索引选择样本 |
train_test_split() | 拆分训练集与测试集 |
shuffle(seed) | 随机打乱数据集 |
remove_columns(columns) | 删除某些列 |
rename_column(old, new) | 重命名列 |
with_format("torch"/"tensorflow") | 转换为深度学习框架格式 |
to_pandas() | 转为 pandas |
to_csv() | 保存为 CSV |
save_to_disk() / load_from_disk() | 本地存储与加载 |
4. 示例:预处理文本数据
def tokenize_function(example):
return tokenizer(example["text"], truncation=True, padding="max_length")
tokenized_dataset = dataset.map(tokenize_function)
5. 示例:过滤文本长度
short_texts = dataset.filter(lambda example: len(example["text"]) < 100)
6. 示例:转换为 PyTorch 可用格式
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
7. 示例:拆分训练和验证集
split_dataset = dataset.train_test_split(test_size=0.2)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
8. 流式数据加载(大数据集)
适用于内存无法加载的大型数据集:
dataset = load_dataset("wikipedia", split="train", streaming=True)
for example in dataset:
print(example)
break
9. 示例:保存与加载本地数据
dataset.save_to_disk("./my_dataset")
from datasets import load_from_disk
dataset = load_from_disk("./my_dataset")
10. 总结
Dataset
是一个非常灵活的对象,适合以下操作:
场景 | 方法 |
---|---|
数据清洗 | .map() , .filter() |
选择样本 | .select() , .shuffle() |
拆分数据 | .train_test_split() |
格式转换 | .set_format() , .to_pandas() |
本地管理 | .save_to_disk() , .load_from_disk() |
如果你正在训练 Transformer 模型,几乎所有的数据预处理都可以直接在 datasets.Dataset
上完成,避免手动构造 DataLoader
或转换为 Tensor 的麻烦。