【学习笔记】【Pytorch】十六、模型训练套路
一、内容概述
本内容主要是介绍一个完整的模型训练套路,以 CIFAR-10 数据集为例。
模型训练步骤:
- 准备数据:创建 datasets 实例
- 加载数据:创建 DataLoader 实例
- 准备模型:神经网络结构
- 设置损失函数
- 设置优化器
- 开始训练:
- 从 batch_size 个数据中分别取出个图片数据和标签数据
- 图片数据输入到神经网络模型里后输出训练结果
- 将训练结果和标签数据经过损失函数
- 调用优化器实例的梯度清零API
- 调用反向传播API
- 调用优化器实例参数优化API
- 开始测试:(使用每轮训练好、但不进行优化的模型)
- 结果聚合展示
二、模型训练套路
1.代码实现:CPU版本
# CIFAR_model.py
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class Model(nn.Module):
def __init__(self) -> None:
super().__init__() # 初始化父类属性
self.conv1 = Conv2d(in_channels=3, out_channels=32,
kernel_size=5, padding=2)
self.maxpool1 = MaxPool2d(kernel_size=2)
self.conv2 = Conv2d(in_channels=32,out_channels=32,
kernel_size=5, padding=2)
self.maxpool2 = MaxPool2d(kernel_size=2)
self.conv3 = Conv2d(in_channels=32, out_channels=64,
kernel_size=5, padding=2)
self.maxpool3 = MaxPool2d(kernel_size=2)
self.flatten = Flatten() # 展平为1维向量,torch.reshape()一样效果
# 若是想检验1024是否正确,可以先写前面的层,看样例的输出大小,即可得到1024
self.linear1 = Linear(in_features=1024, out_features=64)
self.linear2 = Linear(in_features=64, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
if __name__ == '__main__':
model = Model() # 创建实例
# 测试模型样例(也可以测试各层的输出是否正确)
input = torch.ones((64, 3, 32, 32)) # batch_size = 64
print(input.shape) # torch.Size([64, 3, 32, 32])
output = model(input)
print(output.shape) # torch.Size([64, 10]),batch_size=64,10个参数
import time
import torch.optim.optimizer
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from CIFAR_model import Model # 导入CIFAR_model.py里的Model类定义
# 1.创建 CIFAR10 数据集的训练和测试实例
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
# Length 长度
train_data_len = len(train_data)
test_data_len = len(test_data)
print(f"训练数据集的长度:{
train_data_len}")
print(f"测试数据集的长度:{
test_data_len}")
# 2.利用 DataLoader 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size