pytorch加载数据集主要分为两种方法:
1、所使用数据集已被集成在pytorch内,如:CIFAR-10,CIFAR-100,MNIST等等。对于这种数据集,可以直接使用pytorch内置函数:torchvision.datasets.CIFAR100来直接加载,比较方便。例程如下:
transform_train = transforms.Compose([
#transforms.ToPILImage(),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True,
download=True, transform=transform_train)
cifar100_training_loader = DataLoader(
cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
2、所使用数据集为被集成,这个类别是本文的主要讲述内容。
加载自定义数据集(即未被集成在pytorch内)
对于自定义数据集pytorch实际上是有一个函数的:torchvision.datasets.ImageFolder(),但是此函数只能加载特定形式的数据集(图片已被分类好,并放在相应文件夹下了,其标签就是其上层目录的名称,在下面会解释为什么)。有时候我们需要使用的数据集可能不是这样的,其标签可能不在目录上,而在一个