这是我的第409篇原创文章。
一、引言
在深度学习任务中,数据加载和处理是至关重要的一环。
PyTorch 提供了强大的数据加载和处理工具,主要包括:
-
torch.utils.data.Dataset
:数据集的抽象类,需要自定义并实现__len__
(数据集大小)和__getitem__
(按索引获取样本)。 -
torch.utils.data.TensorDataset
:基于张量的数据集,适合处理数据-标签对,直接支持批处理和迭代。 -
torch.utils.data.DataLoader
:封装 Dataset 的迭代器,提供批处理、数据打乱、多线程加载等功能,便于数据输入模型训练。 -
torchvision.datasets.ImageFolder
:从文件夹加载图像数据,每个子文件夹代表一个类别,适用于图像分类任务。
二、实现过程
2.1 torch.utils.data.Dataset
Dataset 是 PyTorch 中用于数据集抽象的类。自定义数据集需要继承 torch.utils.data.Dataset 并重写以下两个方法:
-
__len__:返回数据集的大小。
-
__getitem__:按索引获取一个数据样本及其标签。
代码:
import torch
from torch.utils.data import Dataset
# 自定义数据集
class MyDataset(Dataset):
def __init__(self, data, labels):
# 数据初始化
self.data = data
self.labels = labels
def __len__(self):
# 返回数据集大小
return len(self.data)
def __getitem__(self, idx):
# 按索引返回数据和标签
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 生成示例数据
data = torch.randn(100, 5) # 100 个样本,每个样本有 5 个特征
labels = torch.randint(0, 2, (100,)) # 100 个标签,取值为 0 或 1
# 实例化数据集
dataset = MyDataset(data, labels)
# 测试数据集
print("数据集大小:", len(dataset))
print("第 0 个样本:", dataset[0])
输出结果:
2.2 torch.utils.data.DataLoader
DataLoader 是 PyTorch 提供的数据加载器,用于批量加载数据集。提供了以下功能:
-
批量加载:通过设置
batch_size
。 -
数据打乱:通过设置
shuffle=True
。 -
多线程加速:通过设置
num_workers
。 -
迭代访问:方便地按批次访问数据。
代码:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# 自定义数据集
class MyDataset(Dataset):
def __init__(self, data, labels):
# 数据初始化
self.data = data
self.labels = labels
def __len__(self):
# 返回数据集大小
return len(self.data)
def __getitem__(self, idx):
# 按索引返回数据和标签
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 生成示例数据
data = torch.randn(100, 5) # 100 个样本,每个样本有 5 个特征
labels = torch.randint(0, 2, (100,)) # 100 个标签,取值为 0 或 1
# 实例化数据集
dataset = MyDataset(data, labels)
# 实例化 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
# 遍历 DataLoader
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
print(f"批次 {batch_idx + 1}")
print("数据:", batch_data)
print("标签:", batch_labels)
if batch_idx == 2: # 仅显示前 3 个批次
break
输出结果:
每次循环中,DataLoader 会返回一个批次的数据,包括输入特征(inputs)和目标标签(labels)。
2.3 Dataset 与 DataLoader 的自定义应用
以下是一个将 CSV 文件 作为数据源,并通过自定义 Dataset 和 DataLoader 读取数据。
代码:
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# 自定义 CSV 数据集
class CSVDataset(Dataset):
def __init__(self, file_path):
# 读取 CSV 文件
self.data = pd.read_csv(file_path)
def __len__(self):
# 返回数据集大小
return len(self.data)
def __getitem__(self, idx):
# 使用 .iloc 明确基于位置索引
row = self.data.iloc[idx]
# 将特征和标签分开
features = torch.tensor(row.iloc[:-1].to_numpy(), dtype=torch.float32) # 特征
label = torch.tensor(row.iloc[-1], dtype=torch.float32) # 标签
return features, label
# 实例化数据集和 DataLoader
dataset = CSVDataset("runoob_pytorch_data.csv")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# 遍历 DataLoader
for features, label in dataloader:
print("特征:", features)
print("标签:", label)
break
输出结果:
作者简介:
读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。