一、介绍
torchvision
是 PyTorch 生态中用于计算机视觉的核心工具库,提供数据加载、模型架构、图像变换三大核心功能,大幅简化 CV 开发流程。以下是具体内容和作用:
1. datasets
:内置数据集
- 作用:快速加载经典 CV 数据集,支持自动下载、预处理和划分。
- 常见数据集:
- 分类:MNIST、CIFAR10/100、ImageNet、Caltech101/256
- 检测 / 分割:COCO、Pascal VOC、Cityscapes
- 其他:LSUN、SBD(语义边界数据集)
2. models
:预训练模型库
- 作用:提供预训练的经典模型,支持迁移学习和微调。
- 常见模型:
- 分类:ResNet、VGG、MobileNet、EfficientNet
- 检测:Faster R-CNN、SSD、YOLO(需第三方实现)
- 分割:U-Net、DeepLabV3、Mask R-CNN
- 生成:DCGAN、StyleGAN(部分支持)
3. transforms
:图像变换工具
- 作用:处理图像数据,包括预处理和数据增强。
- 常见变换:
- 预处理:
ToTensor()
、Normalize()
、Resize()
- 增强:
RandomCrop()
、RandomHorizontalFlip()
、ColorJitter()
- 其他:
Grayscale()
、GaussianBlur()
、RandomRotation()
- 预处理:
4. utils
:辅助工具
- 作用:提供可视化、数据处理等实用函数。
- 常见工具:
make_grid()
:将多张图像拼接成网格图(用于可视化)。save_image()
:保存张量图像到文件。draw_bounding_boxes()
:在图像上绘制检测框(需 torch >= 1.10)。
其文档
参数中
root(字符串类型):指定数据集的根目录。
train
(布尔类型,可选,默认 True
):决定加载训练集(True)还是测试集(False)。
transform(可调用对象,可选,默认 None
):对加载的 PIL 图像进行变换操作(传入一个函数或变换对象)。
target_transform
(可调用对象,可选,默认 None
):对标签(target,即图像对应的类别 )进行变换操作(传入一个函数或变换对象)。
download
(布尔类型,可选,默认 False
):决定是否从互联网下载数据集。download=True
时,会检查 root
目录下是否存在数据集,若不存在则从指定源下载;若已存在,则不会重复下载,方便离线使用已下载的数据集。
二、下载CIFAR-10数据集
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
因为我这里下载太慢,所以用迅雷下载完之后再解压缩到该文件夹中
执行
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0])
print(test_set.classes)
设置断点,debug查看test_set属性以及CIFAR-10数据集
结果
第一个print,第一项为数据的图片,第二项为图片中的物体的target(是一个索引,根据这个索引在class中得到相应的target) 。
再执行
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0])
print(test_set.classes)
img,target=test_set[0]
print(img)
print(target)
print(test_set.classes[target])
结果
img.show()展示图片
三、数据集和transforms结合
需要结合的情况
- 模型训练:
- 数据标准化:几乎所有深度学习模型都对输入数据的格式和范围有特定要求。例如,大多数预训练模型(如在 ImageNet 上预训练的模型)期望输入图像是张量格式,且像素值经过归一化处理。CIFAR - 10 数据集中的图像是 PIL 格式,将其转换为张量并归一化,能让模型更高效地学习。此时需要使用
transforms.ToTensor()
和transforms.Normalize()
等变换,结合数据集一起使用,以确保输入数据符合模型要求。 - 数据增强:在训练数据量有限的情况下,为了提高模型的泛化能力,需要对训练数据进行增强。比如,对图像进行随机裁剪、翻转、旋转、颜色抖动等操作,能增加数据的多样性,让模型学习到更多不同的特征表示。像
transforms.RandomCrop()
、transforms.RandomHorizontalFlip()
等数据增强变换,需要结合训练数据集使用,在每次加载训练数据时动态地对图像进行变换。
- 数据标准化:几乎所有深度学习模型都对输入数据的格式和范围有特定要求。例如,大多数预训练模型(如在 ImageNet 上预训练的模型)期望输入图像是张量格式,且像素值经过归一化处理。CIFAR - 10 数据集中的图像是 PIL 格式,将其转换为张量并归一化,能让模型更高效地学习。此时需要使用
- 模型评估与预测(特殊情况):
- 格式统一:即使在模型评估和预测阶段,为了保证输入数据与训练阶段的一致性,也需要对测试或预测数据进行标准化处理,将其转换为张量并归一化,这时就需要结合
transforms
。 - 特定预处理需求:如果模型在训练过程中对数据进行了一些特定的预处理(如调整图像大小到特定尺寸),那么在评估和预测时,也需要对数据进行同样的处理,以保证输入数据符合模型的输入要求。
- 格式统一:即使在模型评估和预测阶段,为了保证输入数据与训练阶段的一致性,也需要对测试或预测数据进行标准化处理,将其转换为张量并归一化,这时就需要结合
不需要结合的情况
- 数据简单查看与分析:
- 仅观察数据原貌:当只是想查看数据集中图像的原始样子,比如在数据探索阶段,想直观地浏览一下 CIFAR - 10 数据集中的图像,了解其内容和分布情况,不需要对图像进行任何格式转换或预处理,此时就不需要结合
transforms
。可以直接从数据集中获取 PIL 格式的图像进行显示。 - 进行简单统计分析:如果只是对数据集中图像的尺寸、颜色模式等基本信息进行统计分析,而不涉及模型训练和预测,也不需要使用
transforms
对图像进行变换。
- 仅观察数据原貌:当只是想查看数据集中图像的原始样子,比如在数据探索阶段,想直观地浏览一下 CIFAR - 10 数据集中的图像,了解其内容和分布情况,不需要对图像进行任何格式转换或预处理,此时就不需要结合
- 使用其他数据处理库:
- 已在其他库中完成处理:如果在数据加载到
torchvision
的数据集之前,已经使用其他图像处理库(如 OpenCV)对数据进行了完整的预处理和变换,并且数据已经符合模型的输入要求,那么在使用torchvision
的数据集时,就不需要再结合transforms
进行重复处理。
- 已在其他库中完成处理:如果在数据加载到
若仅需 “观察图像大致结果”(可视化、尺寸 / 模式 ):无需转 Tensor,直接用 PIL 操作更便捷。
若需 “观察数值特征”(像素分布、参与模型计算 ):转 Tensor 是必要的。
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
writer=SummaryWriter("p10")
for i in range(10):
img,target=test_set[i]
writer.add_image("test_set",img,i)
writer.close()
查看日志
等十张图片