Hugging Face datasets.DatasetDict
类详解
DatasetDict
是 Hugging Face datasets
库中用于 管理多个 Dataset
子集(如 "train"
、"test"
、"validation"
)的字典容器结构。
它是一个继承自 Python dict
的对象,其中每个 key 对应一个 datasets.Dataset
对象。
1. 什么是 DatasetDict
?
它的结构类似于:
{
"train": Dataset,
"test": Dataset,
"validation": Dataset
}
用于统一管理和访问训练、测试、验证等数据子集。
2. 创建与加载 DatasetDict
2.1 从 Hugging Face Hub 加载数据集
from datasets import load_dataset
dataset_dict = load_dataset("imdb")
print(dataset_dict)
输出示例:
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
})
3. 访问与使用 DatasetDict
用法 | 说明 | 示例 |
---|---|---|
dataset_dict["train"] | 获取训练集 | Dataset 对象 |
dataset_dict.keys() | 获取所有子集名称 | ['train', 'test'] |
len(dataset_dict["train"]) | 子集样本数 | 25000 |
dataset_dict["train"][0] | 获取一条数据 | {'text': ..., 'label': ...} |
4. 常用方法
方法 | 作用 |
---|---|
map(function) | 对所有子集应用预处理函数 |
filter(function) | 对所有子集进行筛选 |
remove_columns(cols) | 删除列 |
rename_column(old, new) | 重命名列 |
set_format("torch") | 转换格式以用于 PyTorch |
train_test_split() | 拆分训练集与验证集(只对某个子集) |
shuffle(seed) | 随机打乱所有子集 |
save_to_disk(path) | 保存整个数据字典 |
load_from_disk(path) | 加载数据字典 |
5. 示例用法
5.1 数据预处理
def tokenize_fn(example):
return tokenizer(example["text"], truncation=True, padding="max_length")
tokenized_dataset = dataset_dict.map(tokenize_fn, batched=True)
5.2 删除原始列
tokenized_dataset = tokenized_dataset.remove_columns(["text"])
5.3 拆分训练集为训练 + 验证
split = dataset_dict["train"].train_test_split(test_size=0.1)
dataset_dict["train"] = split["train"]
dataset_dict["validation"] = split["test"]
5.4 设置输出格式为 PyTorch
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
5.5 保存与加载
tokenized_dataset.save_to_disk("./my_dataset_dict")
from datasets import load_from_disk
dataset_dict = load_from_disk("./my_dataset_dict")
6. DatasetDict
vs Dataset
对象 | 描述 | 示例 |
---|---|---|
Dataset | 单一子集(如训练集) | dataset_dict["train"] |
DatasetDict | 多个子集的集合(如 train/test/val) | dataset_dict |
7. 示例结构
{
'train': Dataset(num_rows=1000, features=['text', 'label']),
'test': Dataset(num_rows=200, features=['text', 'label']),
'validation': Dataset(num_rows=200, features=['text', 'label'])
}
你可以像字典一样访问:
dataset_dict["train"]
dataset_dict["validation"]
也可以统一处理所有子集:
dataset_dict = dataset_dict.map(tokenize_fn, batched=True)
8. 总结
能力 | 是否支持 |
---|---|
管理多个数据子集 | 是 |
与 Dataset 同步处理 | 是 |
批量预处理、筛选、格式转换 | 是 |
可存储到本地 | 是 |
兼容 Trainer 训练框架 | 是 |
DatasetDict
是使用 Hugging Face 数据训练模型时的推荐数据结构,特别适合在 训练/验证/测试阶段统一管理数据。