用于处理数据样本的代码可能会变得混乱且难以维护。理想情况下,我们希望将数据集代码与模型训练代码解耦,以提高可读性和模块化程度。PyTorch 提供了两个数据基础工具:torch.utils.data.DataLoader
和 torch.utils.data.Dataset
,它们允许您使用预加载的数据集以及自定义的数据。Dataset
存储样本及其对应的标签,而 DataLoader
则为 Dataset
提供了一个可迭代对象,方便访问样本数据。
PyTorch 的领域库提供了许多预加载的数据集(如 FashionMNIST),这些数据集继承自 torch.utils.data.Dataset
,并实现了特定数据的函数。它们可以用于快速构建和评估您的模型。
加载数据集
以下是如何从 TorchVision 加载 Fashion-MNIST 数据集的示例。Fashion-MNIST 是一个包含 Zalando 商品图像的数据集,由 60,000 个训练样本和 10,000 个测试样本组成。每个样本包含一个 28×28 的灰度图像及其在 10 个类别中的标签。
我们使用以下参数来加载 FashionMNIST 数据集:
root
:存储训练/测试数据的路径。train
:指定是训练数据集还是测试数据集。download=True
:如果在root
路径中数据不可用,则从互联网下载数据。transform
和target_transform
:指定特征和标签的转换。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data", #相对工程文件的地址
train=True, #训练集,所以该参数设置为True
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data", #相对工程文件的地址
train=False, #测试集,所以该参数设置为False
download=True,
transform=ToTensor()
)
文件结构: