我们将通过一个简单的示例,快速了解如何使用 PyTorch 进行机器学习任务。PyTorch 是一个开源的机器学习库,它提供了丰富的工具和库,帮助我们轻松地构建、训练和测试神经网络模型。以下是本教程的主要内容:
一、数据处理
PyTorch 提供了两个基本的数据处理工具:torch.utils.data.DataLoader
和 torch.utils.data.Dataset
。Dataset
用于存储样本及其对应的标签,而 DataLoader
则为 Dataset
提供了一个可迭代的包装器。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch 还提供了许多特定领域的库,如 TorchText、TorchVision 和 TorchAudio,这些库中都包含了各种数据集。在本教程中,我们将使用 TorchVision 中的 FashionMNIST 数据集。每个 TorchVision Dataset
都包含两个参数:transform
和 target_transform
,分别用于修改样本和标签。
# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
我们将 Dataset
作为参数传递给 DataLoader
。DataLoader
为我们的数据集提供了一个可迭代的包装器,并支持自动批处理、采样、洗牌和多进程数据加载。在这里,我们定义了一个大小为 64 的批次,即数据加载器可迭代对象的每个元素将返回一个包含 64 个特征和标签的批次。
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
二、创建模型
在 PyTorch 中,我们通过创建一个继承自 nn.Module
的类来定义神经网络。我们在 __init__
函数中定义网络的层,并在 forward
函数中指定数据如何通过网络传递