Torchvision

一、介绍

torchvision 是 PyTorch 生态中用于计算机视觉的核心工具库,提供数据加载、模型架构、图像变换三大核心功能,大幅简化 CV 开发流程。以下是具体内容和作用:

1. datasets:内置数据集
  • 作用:快速加载经典 CV 数据集,支持自动下载、预处理和划分。
  • 常见数据集
    • 分类:MNIST、CIFAR10/100、ImageNet、Caltech101/256
    • 检测 / 分割:COCO、Pascal VOC、Cityscapes
    • 其他:LSUN、SBD(语义边界数据集)
2. models:预训练模型库
  • 作用:提供预训练的经典模型,支持迁移学习和微调。
  • 常见模型
    • 分类:ResNet、VGG、MobileNet、EfficientNet
    • 检测:Faster R-CNN、SSD、YOLO(需第三方实现)
    • 分割:U-Net、DeepLabV3、Mask R-CNN
    • 生成:DCGAN、StyleGAN(部分支持)
3. transforms:图像变换工具
  • 作用:处理图像数据,包括预处理数据增强
  • 常见变换
    • 预处理ToTensor()Normalize()Resize()
    • 增强RandomCrop()RandomHorizontalFlip()ColorJitter()
    • 其他Grayscale()GaussianBlur()RandomRotation()
4. utils:辅助工具
  • 作用:提供可视化、数据处理等实用函数。
  • 常见工具
    • make_grid():将多张图像拼接成网格图(用于可视化)。
    • save_image():保存张量图像到文件。
    • draw_bounding_boxes():在图像上绘制检测框(需 torch >= 1.10)。

其文档

 参数中

root(字符串类型):指定数据集的根目录。

train(布尔类型,可选,默认 True ):决定加载训练集(True)还是测试集(False)。

transform(可调用对象,可选,默认 None ):对加载的 PIL 图像进行变换操作(传入一个函数或变换对象)。

target_transform(可调用对象,可选,默认 None ):对标签(target,即图像对应的类别 )进行变换操作(传入一个函数或变换对象)。

download(布尔类型,可选,默认 False ):决定是否从互联网下载数据集。download=True 时,会检查 root 目录下是否存在数据集,若不存在则从指定源下载;若已存在,则不会重复下载,方便离线使用已下载的数据集。

二、下载CIFAR-10数据集

import torchvision

train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

因为我这里下载太慢,所以用迅雷下载完之后再解压缩到该文件夹中

 执行

import torchvision

train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

print(test_set[0])
print(test_set.classes)

 设置断点,debug查看test_set属性以及CIFAR-10数据集

结果

第一个print,第一项为数据的图片,第二项为图片中的物体的target(是一个索引,根据这个索引在class中得到相应的target) 。

再执行

import torchvision

train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

print(test_set[0])
print(test_set.classes)

img,target=test_set[0]
print(img)
print(target)
print(test_set.classes[target])

结果 

img.show()展示图片

三、数据集和transforms结合

需要结合的情况

  • 模型训练
    • 数据标准化:几乎所有深度学习模型都对输入数据的格式和范围有特定要求。例如,大多数预训练模型(如在 ImageNet 上预训练的模型)期望输入图像是张量格式,且像素值经过归一化处理。CIFAR - 10 数据集中的图像是 PIL 格式,将其转换为张量并归一化,能让模型更高效地学习。此时需要使用transforms.ToTensor()transforms.Normalize()等变换,结合数据集一起使用,以确保输入数据符合模型要求。
    • 数据增强:在训练数据量有限的情况下,为了提高模型的泛化能力,需要对训练数据进行增强。比如,对图像进行随机裁剪、翻转、旋转、颜色抖动等操作,能增加数据的多样性,让模型学习到更多不同的特征表示。像transforms.RandomCrop()transforms.RandomHorizontalFlip()等数据增强变换,需要结合训练数据集使用,在每次加载训练数据时动态地对图像进行变换。
  • 模型评估与预测(特殊情况)
    • 格式统一:即使在模型评估和预测阶段,为了保证输入数据与训练阶段的一致性,也需要对测试或预测数据进行标准化处理,将其转换为张量并归一化,这时就需要结合transforms
    • 特定预处理需求:如果模型在训练过程中对数据进行了一些特定的预处理(如调整图像大小到特定尺寸),那么在评估和预测时,也需要对数据进行同样的处理,以保证输入数据符合模型的输入要求。

不需要结合的情况

  • 数据简单查看与分析
    • 仅观察数据原貌:当只是想查看数据集中图像的原始样子,比如在数据探索阶段,想直观地浏览一下 CIFAR - 10 数据集中的图像,了解其内容和分布情况,不需要对图像进行任何格式转换或预处理,此时就不需要结合transforms。可以直接从数据集中获取 PIL 格式的图像进行显示。
    • 进行简单统计分析:如果只是对数据集中图像的尺寸、颜色模式等基本信息进行统计分析,而不涉及模型训练和预测,也不需要使用transforms对图像进行变换。
  • 使用其他数据处理库
    • 已在其他库中完成处理:如果在数据加载到torchvision的数据集之前,已经使用其他图像处理库(如 OpenCV)对数据进行了完整的预处理和变换,并且数据已经符合模型的输入要求,那么在使用torchvision的数据集时,就不需要再结合transforms进行重复处理。

若仅需 “观察图像大致结果”(可视化、尺寸 / 模式 ):无需转 Tensor,直接用 PIL 操作更便捷。

若需 “观察数值特征”(像素分布、参与模型计算 ):转 Tensor 是必要的。

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)

writer=SummaryWriter("p10")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("test_set",img,i)
writer.close()

 查看日志

等十张图片

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值