目录
数据加载
若数据集里无分类文件,全是照片,用ImageFolder()时,应在数据集文件里自己建一个文件夹,在写数据路径时应写到上一层文件夹.
如下例路径
路径应为
案例一:猫狗分类
参考:https://blog.csdn.net/qq_42951560/article/details/109950786.
代码位置:E:\项目例程\猫狗分类\迁移学习\猫狗_resnet18_2 \猫狗分类_迁移学习可视化
数据集展示:
数据增强
#---数据增强----
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
数据读取,加载
当数据集已经自觉按照要分配的类型分成了不同的文件夹,一种类型的文件夹下面只存放一种类型的图片,定义数据读取时,使用 torchvision包中的ImageFolder类会比Dataset类会更方便
#----制作数据集----
image_datasets = {
x: datasets.ImageFolder(
root=os.path.join('./catsdogs', x),
transform=data_transforms[x]
) for x in ['train', 'val']
}
#----数据加载器---
dataloaders = {
x: DataLoader(
dataset=image_datasets[x],
batch_size=16,
shuffle=True,
num_workers=0
) for x in ['train', 'val']
}
查看数据集数量及种类
#------相关信息打印-----
dataset_sizes = {
x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(dataset_sizes)
print(class_names)
print(device)
结果展示:
训练数据可视化
#----训练数据可视化-----
inputs, labels = next(iter(dataloaders['train']))
grid_images = torchvision.utils.make_grid(inputs)
def no_normalize(im):
im = im.permute(1, 2, 0)
im = im*torch.Tensor([0.229, 0.224, 0.225])+torch.Tensor([0.485, 0.456, 0.406])
return im
grid_images = no_normalize(grid_images)
plt.title([class_names[x] for x in labels])
plt.imshow(grid_images)
plt.show()
结果展示:
案例二:交通指示牌识别-4分类
参考:b站交通指示牌4分类迁移学习
代码位置:E:\项目例程\交通指示灯\迁移学习_交通道路识别\交通指示牌识别4分类_迁移学习
数据集展示
查看数据集
查看数据集大小
注意:使用opencv加载图片,路径应为全英文
Train_Path= r'E:\data\GTSRB\Final_Training\Images\*'
Test_Path = r'E:\data\GTSRB\Final_Test\Images\*'
# 训练集文件夹
train_folders = sorted(glob(Train_Path)) #根据路径读取文件
print("训练集类别",len(train_folders))
# 测试集文件
test_files = sorted(glob(Test_Path))
print("测试集数据"<