Building a Vision Transformer from Scratch in PyTorch
Vision Transformers (ViTs) have revolutionized the field of computer vision by leveraging transformer architecture, which was originally designed for natural language processing. Unlike traditional CNNs, ViTs divide an image into patches and treat them as tokens, allowing the model to learn spatial relationships effectively. In this tutorial, we’ll walk through building a Vision Transformer from scratch using PyTorch, from setting up the environment to fine-tuning the model.
A Vision Transformer (ViT) is a deep learning architecture designed to apply transformers to computer vision tasks. Traditionally, convolutional neural networks (CNNs) have been the dominant model for vision-based applications, but ViTs offer a novel approach. Instead of using convolutions to process images, ViTs split an image into smaller patches and treat each patch as a token (similar to words in NLP), feeding them into a transformer model. The ViT model captures long-range dependencies in an image, making it particularly effective for tasks like image classification.
- Patch Embedding: Instead of processing the entire image at once, Vision Transformers divide an image into non-overlapping patches and project each one into a fixed-size embedding (a quick size check follows this list).
- Positional Encoding: Since transformers are permutation-invariant, positional information is crucial for them to understand spatial relationships between image patches.
- Multi-Head Self-Attention: This mechanism allows the model to attend to multiple patches at once, capturing both local and global dependencies.
- Feed-Forward Network: A multi-layer perceptron (MLP) processes the attention-weighted patches to generate the final output.
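As a quick back-of-the-envelope check of these components (using the 224×224 input and 16×16 patch size that the code below defaults to; these are just the common ViT-Base choices, not a requirement), an image yields (224/16)² = 196 patch tokens, and each patch holds 16·16·3 = 768 raw pixel values before projection:
Python
img_size, patch_size, channels = 224, 16, 3          # common ViT-Base-style settings (assumed here)
num_patches = (img_size // patch_size) ** 2          # 14 * 14 = 196 tokens per image
patch_values = patch_size * patch_size * channels    # 16 * 16 * 3 = 768 raw values per patch
print(num_patches, patch_values)                     # 196 768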
Transformers have proven highly effective in natural language processing (NLP), particularly in tasks requiring attention mechanisms. By applying transformers to vision tasks, we can overcome some of the limitations of CNNs:
- Global Attention: Transformers excel at capturing global dependencies through self-attention, allowing them to attend to the entire image at once rather than only the local neighborhoods that CNN filters see.
- Scalability: Vision Transformers are highly scalable, performing well when pre-trained on large datasets (e.g., ImageNet).
- Fewer Inductive Biases: Unlike CNNs, which rely heavily on the locality of pixels, transformers make fewer assumptions about the input, potentially learning more complex relationships in the data.
Let's implement a Vision Transformer from scratch in PyTorch, including patch embedding, positional encoding, multi-head attention, transformer encoder blocks, and training on the CIFAR-10 dataset. Below is a step-by-step guide.
1. Dividing the Image into Patches
Vision Transformers first divide an image into fixed-size patches. Each patch is flattened into a vector and embedded with a linear projection. In the implementation below, a single Conv2d layer whose kernel size and stride both equal the patch size performs the split and projection in one operation.
Python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Splits an image into patches and projects each patch to embed_dim."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # Conv2d with kernel_size = stride = patch_size is equivalent to
        # flattening each non-overlapping patch and applying a linear projection
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
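A quick shape check (the random tensor below is purely illustrative) confirms that a batch of 224×224 RGB images becomes a sequence of 196 patch tokens of dimension 768:
Python
patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)    # batch of 2 RGB images
tokens = patch_embed(dummy)
print(tokens.shape)                    # torch.Size([2, 196, 768])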
2. Adding Positional Embeddings
Since transformers don’t have a built-in sense of order, we need to add positional information to each patch to capture the spatial relationships.
Python
class PositionalEncoding(nn.Module):
    """Adds learnable positional embeddings to the token sequence."""
    def __init__(self, embed_dim, seq_len):
        super().__init__()
        # seq_len + 1 accounts for the [CLS] token prepended to the patch tokens
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

    def forward(self, x):
        return x + self.pos_embed
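To see how the sequence length lines up, here is a small check (the [CLS] token tensor is created here only for illustration; the full model defines its own learnable one later):
Python
pos_enc = PositionalEncoding(embed_dim=768, seq_len=196)
cls_token = torch.zeros(2, 1, 768)                       # illustrative [CLS] token for a batch of 2
seq = torch.cat([cls_token, torch.randn(2, 196, 768)], dim=1)  # (2, 197, 768) = 196 patches + [CLS]
print(pos_enc(seq).shape)                                # torch.Size([2, 197, 768])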
3. Defining the Multi-Head Self-Attention Mechanism
Multi-head self-attention allows the model to focus on different parts of the image simultaneously, capturing both local and global features.
Python
class MultiHeadAttention(nn.Module):
    """Thin wrapper around nn.MultiheadAttention for self-attention."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # batch_first=True so the module accepts (batch, seq_len, embed_dim) tensors,
        # which is the layout produced by PatchEmbedding
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # Self-attention: queries, keys and values are all the same sequence
        return self.attn(x, x, x)[0]
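A minimal check of the attention wrapper (shapes only; the random tensor is illustrative):
Python
attn = MultiHeadAttention(embed_dim=768, num_heads=8)
x = torch.randn(2, 197, 768)    # (batch, tokens, embed_dim)
print(attn(x).shape)            # torch.Size([2, 197, 768])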
4. Building the Transformer Encoder Block
A full transformer encoder block consists of a multi-head self-attention layer followed by a feed-forward network, with layer normalization and a residual connection around each sub-layer.
Python
class TransformerEncoderBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""
    def __init__(self, embed_dim, num_heads, mlp_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Residual connections around the attention and MLP sub-layers
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
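A single block preserves the sequence shape, which is what allows several of them to be stacked (the sizes below are illustrative):
Python
blocks = nn.Sequential(*[TransformerEncoderBlock(768, 8, 1024) for _ in range(2)])
x = torch.randn(2, 197, 768)
print(blocks(x).shape)          # torch.Size([2, 197, 768]) -- shape is preserved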
5. Assembling the Vision Transformer
Finally, we can stack the transformer blocks and define the Vision Transformer model, adding a classification head at the end.
Python
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=10, embed_dim=768, num_heads=8, depth=6, mlp_dim=1024):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, (img_size // patch_size) ** 2)
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim) for _ in range(depth)
        ])
        # Learnable [CLS] token whose final state is used for classification
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embedding(x)                    # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)          # prepend [CLS]: (B, num_patches + 1, embed_dim)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x)
        return self.mlp_head(x[:, 0])                  # classify from the [CLS] token
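Before training, a forward pass on a random batch (illustrative only) verifies that the assembled model produces one logit per class:
Python
model = VisionTransformer(num_classes=10)
dummy = torch.randn(2, 3, 224, 224)
print(model(dummy).shape)       # torch.Size([2, 10])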
6. Training on CIFAR-10
To train the model, we can use a simple dataset such as CIFAR-10 and define a training loop.
Python
import torch.optim as optim
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),   # CIFAR-10 images are 32x32; resize to the ViT input size
    transforms.ToTensor()
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(5):  # Train for 5 epochs
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Move the batch to the same device as the model
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/5], Loss: {running_loss/len(train_loader)}")
Output:
Files already downloaded and verified
Epoch [1/5], Loss: 2.761860250130115
Epoch [2/5], Loss: 2.3324048172870815
Epoch [3/5], Loss: 2.324295696965106
Epoch [4/5], Loss: 2.3209078250904533
Epoch [5/5], Loss: 6.058996846106902
After running this implementation on CIFAR-10 for 5 epochs, the loss drops over the first few epochs, indicating that the model is starting to learn, but it spikes in the final epoch. Training a ViT from scratch on a small dataset like CIFAR-10 is sensitive to the learning rate; a smaller learning rate, warmup with a learning-rate schedule, data augmentation, and more epochs typically give more stable results.
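Once training finishes, accuracy on the held-out CIFAR-10 test split can be measured with a short evaluation loop. The sketch below is an addition to the original tutorial and reuses the model, transform, and device defined above:
Python
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

model.eval()
correct = total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)      # predicted class per image
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {100 * correct / total:.2f}%")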
Conclusion
Building a Vision Transformer (ViT) from scratch in PyTorch comes down to understanding the key components of the transformer architecture, such as patch embedding, self-attention, and positional encoding, and applying them to vision tasks. By training the model on datasets like CIFAR-10, we can leverage the power of transformers in computer vision. While the implementation may seem complex, ViTs provide a highly effective alternative to traditional CNNs, particularly for tasks that require capturing long-range dependencies within an image. Fine-tuning and optimization further enhance the model's performance.