Convolutional Autoencoder in Pytorch On MNIST Dataset - by Eugenia Anello - DataSeries - Medium
Published in DataSeries
Illustration by Author
This post is the seventh in a series of guides to building deep learning models with PyTorch.
Below is the full series:
8. Denoising Autoencoder
9. Variational Autoencoder
The goal of the series is to make PyTorch as intuitive and accessible as possible through
example implementations. There are many tutorials on the Internet that use PyTorch to build
many types of challenging models, but they can also be confusing because there are always
slight differences from one tutorial to the next. In this series, I want to start from the
simplest topics and move on to the more advanced ones.
Autoencoder
The autoencoder is an unsupervised deep learning algorithm that learns encoded
representations of the input data and then reconstructs the same input as output. It
consists of two networks, the Encoder and the Decoder. The Encoder compresses the
high-dimensional input into a low-dimensional representation, also called the latent
code or encoded space, to extract the most relevant information from it, while the
Decoder decompresses the encoded data and recreates the original input.
The goal of this architecture is to preserve as much information as possible when
encoding while minimizing the reconstruction error. The reconstruction error, also
called the reconstruction loss, is usually the mean-squared error between the
reconstructed input and the original input when the input is real-valued. When the
input data is categorical, the Cross-Entropy Loss is used instead.
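To make the reconstruction loss concrete, here is a minimal sketch in plain Python. The helper name mse and the four pixel values are made up for illustration; the formula is the same mean of squared differences that torch.nn.MSELoss computes with its default reduction.

```python
def mse(original, reconstructed):
    """Mean-squared error between two equally sized sequences of pixel values."""
    if len(original) != len(reconstructed):
        raise ValueError("inputs must have the same length")
    squared_diffs = [(o - r) ** 2 for o, r in zip(original, reconstructed)]
    return sum(squared_diffs) / len(squared_diffs)


# Hypothetical 4-pixel "image" and its reconstruction (values in [0, 1])
original = [0.0, 0.5, 1.0, 1.0]
reconstructed = [0.0, 0.5, 0.5, 1.0]

# Only the third pixel differs (by 0.5), so the error is 0.25 / 4
error = mse(original, reconstructed)
print(error)
```

A perfect reconstruction gives an error of exactly zero, and the error grows quadratically with the per-pixel difference.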
Implementation in Pytorch
https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac 2/18
1/17/23, 5:59 PM Convolutional Autoencoder in Pytorch on MNIST dataset | by Eugenia Anello | DataSeries | Medium
The encoder has three convolutional layers and two fully connected layers. Some batch
norm layers are added as regularizers. The decoder will have the same architecture but
in inverse order.
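It helps to check how the image size changes through the network. The sketch below applies the standard output-size formulas for convolutions and transposed convolutions (dilation 1) to the kernel, stride, and padding values used in this post, assuming 28x28 MNIST inputs. The helper names conv_out and conv_transpose_out are my own.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Output side length of a transposed convolution with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: three 3x3 convolutions with stride 2 and padding 1, 1, 0
size = 28
encoder_sizes = [size]
for kernel, stride, padding in [(3, 2, 1), (3, 2, 1), (3, 2, 0)]:
    size = conv_out(size, kernel, stride, padding)
    encoder_sizes.append(size)

# Decoder: three 3x3 transposed convolutions with stride 2
# given as (kernel, stride, padding, output_padding)
decoder_sizes = [size]
for kernel, stride, padding, output_padding in [(3, 2, 0, 0), (3, 2, 1, 1), (3, 2, 1, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding, output_padding)
    decoder_sizes.append(size)

print(encoder_sizes)  # [28, 14, 7, 3]
print(decoder_sizes)  # [3, 7, 14, 28]
```

The final 3x3 feature map with 32 channels explains the 3 * 3 * 32 input size of the first linear layer, and the decoder mirrors the same sizes back up to 28x28.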
import torch
from torch import nn

class Encoder(nn.Module):

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )

        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )

    def forward(self, x):
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x

class Decoder(nn.Module):

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1,
                                      unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3,
                               stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2,
                               padding=1, output_padding=1)
        )

    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

3. Initialize Loss function and the optimizer
We need to define the building blocks before training the autoencoder:
- torch.device to train the model with a hardware accelerator like the GPU
- the Encoder and Decoder networks, that will be moved to the device
- nn.MSELoss and torch.optim.Adam

### Define the loss function
loss_fn = torch.nn.MSELoss()

### Define an optimizer (both for the encoder and the decoder!)
lr = 0.001

### Set the random seed for reproducible results
torch.manual_seed(0)

### Initialize the two networks
d = 4

#model = Autoencoder(encoded_space_dim=encoded_space_dim)
encoder = Encoder(encoded_space_dim=d, fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d, fc2_input_dim=128)
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)

# Check if the GPU is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

We define a function to train the AE model. First, we pass the input images to the
encoder. The encoded data is then passed to the decoder, and we compute the
reconstruction loss with loss_fn(x_hat, x). Next, we clear the gradients so that they
do not accumulate across batches, perform backpropagation to compute the new
gradients, and finally update the weights by calling opt.step().
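The steps above can be sketched as follows. This is a minimal sketch, not necessarily the exact function used in this post: the name train_epoch and its argument order match the call in the training loop, while the body, the tiny linear stand-in networks, and the random batches are assumptions made to keep the example self-contained.

```python
import torch
from torch import nn


def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    """Run one pass over the data and return the mean batch loss as a float."""
    encoder.train()
    decoder.train()
    batch_losses = []
    for image_batch, _ in dataloader:  # labels are not needed (unsupervised)
        image_batch = image_batch.to(device)
        encoded = encoder(image_batch)        # compress the input
        decoded = decoder(encoded)            # reconstruct it
        loss = loss_fn(decoded, image_batch)  # reconstruction loss
        optimizer.zero_grad()                 # clear old gradients
        loss.backward()                       # compute new gradients
        optimizer.step()                      # update the weights
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


# Hypothetical stand-ins: linear "networks" and three random batches of 8 features
torch.manual_seed(0)
toy_encoder = nn.Linear(8, 2)
toy_decoder = nn.Linear(2, 8)
toy_batches = [(torch.rand(4, 8), torch.zeros(4)) for _ in range(3)]
toy_optim = torch.optim.SGD(
    list(toy_encoder.parameters()) + list(toy_decoder.parameters()), lr=0.01)

toy_loss = train_epoch(toy_encoder, toy_decoder, torch.device("cpu"),
                       toy_batches, nn.MSELoss(), toy_optim)
print(toy_loss)
```

Any iterable that yields (images, labels) pairs works as the dataloader here, which is why a plain list of batches is enough for a quick check.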
We also would like to see the reconstructed images during each epoch of the training.
The goal is to understand how the autoencoder is learning from the input images.
import matplotlib.pyplot as plt
import numpy as np

def plot_ae_outputs(encoder, decoder, n=10):
    plt.figure(figsize=(16, 4.5))
    targets = test_dataset.targets.numpy()
    # Index of the first test image of each digit 0..n-1
    t_idx = {i: np.where(targets == i)[0][0] for i in range(n)}
    for i in range(n):
        # Top row: original image
        ax = plt.subplot(2, n, i + 1)
        img = test_dataset[t_idx[i]][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img = decoder(encoder(img))
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Original images')
        # Bottom row: reconstructed image
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Reconstructed images')
    plt.show()

Now we can finally begin to train the model on the training set and evaluate it on the
validation set.
num_epochs = 30
diz_loss = {'train_loss': [], 'val_loss': []}
for epoch in range(num_epochs):
    train_loss = train_epoch(encoder, decoder, device,
                             train_loader, loss_fn, optim)
    val_loss = test_epoch(encoder, decoder, device, test_loader, loss_fn)
    print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(
        epoch + 1, num_epochs, train_loss, val_loss))
    diz_loss['train_loss'].append(train_loss)
    diz_loss['val_loss'].append(val_loss)
    plot_ae_outputs(encoder, decoder, n=10)

After 30 epochs, the autoencoder reconstructs the images well, although some
imperfections remain. Considering how simple this model is, it performs very well.
Now that the model is trained, we can do a final evaluation on the test set:

test_epoch(encoder, decoder, device, test_loader, loss_fn).item()

We can also observe how the reconstruction losses decrease over the epochs:
# Plot losses
plt.figure(figsize=(10, 8))
plt.semilogy(diz_loss['train_loss'], label='Train')
plt.semilogy(diz_loss['val_loss'], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
#plt.grid()
plt.legend()
#plt.title('loss')
plt.show()

import torchvision

def show_image(img):
    npimg = img.numpy()
    # Move channels last: (C, H, W) -> (H, W, C) for matplotlib
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

encoder.eval()
decoder.eval()

with torch.no_grad():
    # calculate mean and std of the latent code, taking a batch of test images as input
    images, labels = next(iter(test_loader))
    images = images.to(device)
    latent = encoder(images)
    latent = latent.cpu()

    mean = latent.mean(dim=0)
    print(mean)
    std = (latent - mean).pow(2).mean(dim=0).sqrt()
    print(std)

    # sample latent vectors from the normal distribution
    latent = torch.randn(128, d)*std + mean

    # reconstruct images from the random latent vectors
    latent = latent.to(device)
    img_recon = decoder(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

Note that this procedure keeps the samples in the same region as the latent codes,
yet some of the generated digits make no sense. This is explained by the fact that the
latent space of the autoencoder is extremely irregular: close points in the latent
space can produce very different and meaningless patterns over visible units. For this
reason, the autoencoder doesn’t perform well for generative purposes.
import pandas as pd
from tqdm import tqdm

encoded_samples = []
for sample in tqdm(test_dataset):
    img = sample[0].unsqueeze(0).to(device)
    label = sample[1]
    # Encode image
    encoder.eval()
    with torch.no_grad():
        encoded_img = encoder(img)
    # Append to list
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_samples.append(encoded_sample)
# One row per test image: d latent variables plus the digit label
encoded_samples = pd.DataFrame(encoded_samples)
encoded_samples

Let’s plot the latent space representation using the Plotly Express library:
import plotly.express as px

px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1',
           color=encoded_samples.label.astype(str), opacity=0.7)

From this plot, we see that similar digits are clustered together. For example, the
“4”s overlap with the “9”s and the “5”s.
You can see that it clearly distinguishes one digit from another. There are some
exceptions, with points that fall in other categories, but t-SNE still remains an
improvement compared to the previous representation.
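For reference, a t-SNE projection like the one discussed here can be sketched as follows. This is a minimal sketch, assuming scikit-learn's TSNE is available. The array latent_codes is a synthetic stand-in (three random clusters of 4-dimensional points), not the actual MNIST encodings, and the perplexity value is an arbitrary choice.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Hypothetical stand-in for the encoded test set:
# 3 well-separated clusters of 4-dimensional latent codes, 40 points each
centers = rng.normal(scale=5.0, size=(3, 4))
latent_codes = np.vstack([c + rng.normal(size=(40, 4)) for c in centers])
labels = np.repeat(np.arange(3), 40)

# Project the latent codes down to 2 dimensions for plotting
tsne = TSNE(n_components=2, perplexity=20, random_state=0)
embedding = tsne.fit_transform(latent_codes)

print(embedding.shape)  # (120, 2)
```

With the real data, the input would be the latent-variable columns of encoded_samples (without the label column), and the two output columns can be passed to a scatter plot colored by label.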
Final thoughts:
Congratulations! You have learned to implement a convolutional autoencoder. There
aren’t many tutorials that cover autoencoders with convolutional layers in PyTorch,
so I wanted to contribute in some way. The autoencoder provides a way to compress
images and extract the most important information. There are also many extensions
of this model that improve its performance, such as the Denoising Autoencoder, the
Variational Autoencoder, and Generative Adversarial Networks.
The GitHub code is here. Thanks for reading. Have a nice day.
Reference:
[1] https://github.com/smartgeometry-ucl/dl4g