课程地址
最近做实验发现自己还是基础框架上掌握得不好,于是开始重学一遍PyTorch框架,这个是课程笔记,此节课很详细,笔记记的比较粗,这个视频课是需要有点深度学习数学基础的,如果没有数学基础,可以一边学一边查一查
1. Transforms
我们导入到数据集中的图片可能大小不一样,数据并不总是以训模型所需的最终处理形式出现。我们使用Transforms对数据进行一些操作,使其适合训练(比如统一图片的像素值)。一般Transforms函数都在Dataset中去定义好。target_transfrom是同样的道理,比如我们得到的分类结果是整型的,然后我们能够做一个target_transfrom处理变成One-hot这种数据格式。无论是Transforms还是target_transform,我们都要在Dataset中去定义好。最后通过get_item方法返回到正确的符合规范的标签或样本。
在深度学习中,One-hot编码是一种将分类变量转换成向量表示的方法。在One-hot编码中,向量的长度与类别总数相同,而每个类别对应一个位置,在该位置上One-hot编码表示为1,而其余位置则表示为0。这种表示方法确保了每个类别之间是相互独立的,不存在任何顺序或距离的关系。
One-hot编码在处理分类特征时非常有用,特别是当特征的取值类别非常多时。它能够为神经网络提供一个清晰的输入表示,有助于模型的学习和泛化。例如,在处理文本数据时,每个单词可以被视为一个单独的分类特征,而One-hot编码则可以将每个单词转换为一个向量,其中只有在该单词对应的索引位置上有一个1,其余位置都是0。
使用One-hot编码的主要步骤通常包括:
- 构建词汇表:收集所有可能的类别,并按照某种顺序进行排序。
- 将文本数据转换为整数序列:将文本中的每个单词或字符映射到词汇表中的相应索引。
- 获取词汇表大小:确定One-hot编码向量的长度,即类别总数。
- 将整数序列转换为One-hot编码:对于每个整数序列,创建一个长度与词汇表大小相同的向量,并在相应的位置上设置为1,其他位置设置为0。
需要注意的是,虽然One-hot编码在某些情况下非常有用,但在处理具有顺序或距离关系的数据时,它可能不是最佳选择。此外,当类别数量非常多时,One-hot编码可能会导致大量的零值,这可能会对模型的训练和性能产生影响。在这种情况下,可能需要考虑其他编码方法,如嵌入向量等。
# 导入必要的库和模块
import torch
from torchvision import datasets # torchvision 中的 datasets 模块,用于加载标准数据集
from torchvision.transforms import ToTensor, Lambda # torchvision 的 transforms 模块提供数据预处理的方法
# 加载 FashionMNIST 数据集
ds = datasets.FashionMNIST(
root="data", # 数据存储的根目录,如果数据未下载则会下载到这个目录下
train=True, # 指定这是训练数据集,若设为 False 则加载测试数据集
download=True, # 如果数据未在 'root' 指定的路径下,则下载数据
transform=ToTensor(), # 指定一个转换函数,将 PIL 图片或者 numpy.ndarray 转换为 FloatTensor,并且归一化到 [0, 1] 区间
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
# target_transform 是一个用于处理目标变量的转换函数,
# 这里使用 Lambda 创建一个匿名函数,它将整数标签 y 转换为一个 one-hot 编码的 tensor。
# 具体过程如下:
# 1. torch.zeros(10, dtype=torch.float) 创建一个长度为 10,类型为 float 的零向量,
# 对应于 FashionMNIST 数据集中的 10 个类别。
# 2. scatter_(0, torch.tensor(y), value=1) 将上述零向量的第 y 位置设为 1,实现 one-hot 编码。
# scatter_ 的第一个参数是维度,这里为 0,表示在零维上操作(这里向量只有一维)。
)
2. 构建一个神经网络模型
神经网络由对数据执行操作的层/模块组成。torch.nn命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch中的每个模块都将nn子类化。单元神经网络是由其他模块(层)组成的模块本身。这种嵌套结构允许轻松地构建和管理复杂的体系结构。下面的代码讲解分类模型常用的东西,内容都在注释中:
import os
import torch
from torch import nn