Load a Computer Vision Dataset in PyTorch
Last Updated :
28 Apr, 2025
Computer vision is a subfield of artificial intelligence that enables computers to interpret images. In deep learning, convolutional neural networks (CNNs) are commonly used to process images. Building a good model requires a large number of images.
There are several ways to load a computer vision dataset in PyTorch, depending on the format of the dataset and the specific requirements of your project.
One popular method is to use the built-in dataset classes in torchvision.datasets. This module provides a convenient way to load and preprocess common computer vision datasets, such as CIFAR-10 and ImageNet. For example, to load the CIFAR-10 dataset, you can use the following code:
Python3
# Import the necessary library
import torchvision.datasets as datasets
# Download the cifar Dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
Output:
[Image: CIFAR-10]
The code above will download the CIFAR-10 dataset and save it in the './data' directory.
Another method is to wrap a dataset in the torch.utils.data.DataLoader class. This is useful when the data is on your local machine and you want to shuffle it, specify the batch size, and combine it with data augmentation applied through transforms. DataLoader also lets you customize the loading order, batching, and single- or multi-process data loading.
The transforms.Compose function from torchvision chains several transformations together. The following code uses it to randomly flip and rotate each CIFAR-10 image, convert it to a tensor, and normalize it.
Python3
# Import the necessary libraries
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# Image Transformation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35])
])
# Load the dataset with transformation (already downloaded, so download=False)
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)
# Make batches of size 32
train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)
View the train and test data
Python3
#Train Dataset
print(train_loader.dataset)
#Test Dataset
print(test_loader.dataset)
Output:
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Plot the image:
Python3
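# Import the plotting library
import matplotlib.pyplot as plt
# Fetch one batch of images and labels
inputs, labels = next(iter(train_loader))
# Define the class names
class_name = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'
              }
# Plot the 32 images of the batch in a 4 x 8 grid
# (a moderate dpi keeps the figure from using excessive memory)
plt.figure(figsize=(30, 16), dpi=100)
for i in range(32):
    plt.subplot(4, 8, i + 1)
    # Convert from (C, H, W) to (H, W, C); imshow clips normalized values outside [0, 1]
    plt.imshow(inputs[i].numpy().transpose((1, 2, 0)))
    plt.axis('off')
    plt.title(class_name[int(labels[i])])
plt.show()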
Output:
[Image: grid of CIFAR-10 images]
Other libraries, such as albumentations, can also be used to preprocess and augment the images. The right choice depends on the format of your data and what you are trying to achieve.
You might also want to check the version of PyTorch you're using, as well as the format of the dataset you're trying to load. Some datasets are stored in a custom format, in which case you might need to write your own code to load them correctly.
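For a custom format, the usual approach is to subclass torch.utils.data.Dataset and implement __len__ and __getitem__. The sketch below is one possible shape for such a class; ArrayImageDataset is a hypothetical name, and the in-memory random tensors stand in for whatever parsing a real format would need.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ArrayImageDataset(Dataset):
    """Hypothetical dataset wrapping in-memory images and labels."""

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.images)

    def __getitem__(self, index):
        # Return one (image, label) pair, applying the optional transform
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

# Random stand-in data: 10 images of shape 3x32x32 with labels in [0, 10)
images = torch.rand(10, 3, 32, 32)
labels = torch.randint(0, 10, (10,))

custom_dataset = ArrayImageDataset(images, labels)
custom_loader = DataLoader(custom_dataset, batch_size=4, shuffle=False)

batch_images, batch_labels = next(iter(custom_loader))
print(batch_images.shape, batch_labels.shape)
```

Once the class implements these two methods, it can be passed to DataLoader in the same way as the built-in datasets.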