keras实战-入门之VAE变分自编码器

本文介绍使用Keras实现变分自编码器(VAE)的过程,通过Mnist手写数字集训练模型,实现图像重构与可视化。文中详细展示了VAE的编码器与解码器构建,损失函数计算,以及如何利用VAE进行图像插值和聚类分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

keras实战-入门之VAE变分自编码器

VAE变分自编码器

原理就不说了,可以看这篇文章
代码是参考keras官方的稍微改了点,加了点注释,最后加了个可交互可视化的东西,方便调整看结果,可以看到插值的图片。

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
from keras.callbacks import Callback
import numpy as np
import matplotlib.pyplot as plt
import os
Using TensorFlow backend.
def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist_多层感知机",
                image1_name="二维聚类图.png",
                image2_name="特征数字图.png"
                ):
    """Visualize a trained 2-D-latent VAE.

    Saves and shows two figures inside the ``model_name`` directory:
    a scatter plot of the encoded test set in latent space, and a grid
    of digits decoded from a lattice of latent points.

    # Arguments
        models (tuple): (encoder, decoder) Keras models
        data (tuple): (x_test, y_test) test images and labels
        batch_size (int): batch size for encoder.predict
        model_name (string): output directory (created if missing)
        image1_name (string): filename of the latent scatter plot
        image2_name (string): filename of the decoded-digit grid
    """

    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    # --- figure 1: 2-D latent means of the test set, colored by label ---
    out_path = os.path.join(model_name, image1_name)
    z_mean, _, _ = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(out_path)
    plt.show()

    # --- figure 2: decode an n x n lattice of latent points into digits ---
    out_path = os.path.join(model_name, image2_name)
    n = 30            # lattice resolution: 30 x 30 = 900 digits
    digit_size = 28   # MNIST digit side length in pixels
    canvas = np.zeros((digit_size * n, digit_size * n))
    xs = np.linspace(-4, 4, n)        # z[0] sweep, left to right
    ys = np.linspace(-4, 4, n)[::-1]  # z[1] sweep, reversed so it reads top-down

    for row, yi in enumerate(ys):
        for col, xi in enumerate(xs):
            # decode one latent point and paste the 28x28 tile into the canvas
            decoded = decoder.predict(np.array([[xi, yi]]))
            tile = decoded[0].reshape(digit_size, digit_size)
            r0 = row * digit_size
            c0 = col * digit_size
            canvas[r0:r0 + digit_size, c0:c0 + digit_size] = tile

    plt.figure(figsize=(10, 10))
    # one tick at the center of every tile, labelled with its latent value
    half = digit_size // 2
    tick_pixels = np.arange(half, (n - 1) * digit_size + half + 1, digit_size)
    plt.xticks(tick_pixels, np.round(xs, 1))
    plt.yticks(tick_pixels, np.round(ys, 1))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(canvas, cmap='Greys_r')
    plt.savefig(out_path)
    plt.show()
# MNIST dataset: 60k training and 10k test grayscale digit images (28x28)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# side length of one square image (28)
image_size = x_train.shape[1]
print(image_size)
print(x_train.shape)
28
(60000, 28, 28)
# flatten each 28 x 28 image into one 784-element vector
original_dim = image_size * image_size
x_train = x_train.reshape(-1,original_dim)
x_test = x_test.reshape(-1,original_dim)
# scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
print(x_train.shape)
print(x_test.shape)
(60000, 784)
(10000, 784)
# Network parameters
input_shape = (original_dim, )
# number of hidden-layer units
intermediate_dim = 512
# mini-batch size
batch_size = 128
# dimensionality of the latent vector z
latent_dim = 2

# Encoder: input -> hidden ReLU layer -> (z_mean, z_log_var)
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
# This layer outputs log(variance): working in log-space keeps the math
# simple (using the raw variance would require a square root later on).
z_log_var = Dense(latent_dim, name='z_log_var')(x)

WARNING:tensorflow:From F:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
# 替代z的采样,不然不好求导
# instead of sampling from Q(z|X), sample epsilon = N(0,I)
# z = z_mean + sqrt(var) * epsilon
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """
    
    z_mean, z_log_var = args
    print(K.shape(z_mean))
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    print(K.shape(z_mean),batch,dim)
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Sampling trick: draw epsilon from a standard Gaussian and combine it with
# the encoder's z_mean / z_log_var, producing the 2-D latent vector z.
# A Lambda layer wraps the custom sampling function as a Keras layer.
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
Tensor("z/Shape:0", shape=(2,), dtype=int32)
Tensor("z/Shape_2:0", shape=(2,), dtype=int32) Tensor("z/strided_slice:0", shape=(), dtype=int32) 2
# Encoder model: maps an input image to (z_mean, z_log_var, sampled z)
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoder_input (InputLayer)      (None, 784)          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 512)          401920      encoder_input[0][0]              
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            1026        dense_1[0][0]                    
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            1026        dense_1[0][0]                    
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 403,972
Trainable params: 403,972
Non-trainable params: 0
__________________________________________________________________________________________________
# Decoder: latent vector -> hidden ReLU layer -> 784 sigmoid pixel values
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)
# Decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
z_sampling (InputLayer)      (None, 2)                 0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               1536      
_________________________________________________________________
dense_3 (Dense)              (None, 784)               402192    
=================================================================
Total params: 403,728
Trainable params: 403,728
Non-trainable params: 0
_________________________________________________________________
# Reconstruct the input: feed the encoder's sampled z (third output, index 2)
# through the decoder.
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
Tensor("encoder/z/Shape:0", shape=(2,), dtype=int32)
Tensor("encoder/z/Shape_2:0", shape=(2,), dtype=int32) Tensor("encoder/z/strided_slice:0", shape=(), dtype=int32) 2

models = (encoder, decoder)
data = (x_test, y_test)
# whether to use mean squared error (otherwise binary cross-entropy)
use_mse=True

# Loss = reconstruction error + KL divergence between N(z_mean, var) and N(0, I)
if use_mse:
    reconstruction_loss = mse(inputs, outputs)
else:
    reconstruction_loss = binary_crossentropy(inputs,
                                              outputs)

# mse/bce average over pixels; rescale so the loss sums over all 784 pixels
reconstruction_loss *= original_dim
# KL(N(mean, var) || N(0, 1)) = 0.5 * sum(exp(logvar) - 1 - logvar + mean^2)
kl_loss = K.exp(z_log_var)-(1 + z_log_var )+K.square(z_mean)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= 0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()



_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   (None, 784)               0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 403972    
_________________________________________________________________
decoder (Model)              (None, 784)               403728    
=================================================================
Total params: 807,700
Trainable params: 807,700
Non-trainable params: 0
_________________________________________________________________


F:\Anaconda3\lib\site-packages\ipykernel_launcher.py:23: UserWarning: Output "decoder" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "decoder" during training.
# Callback invoked after every epoch to render the latent-space figures.
class Image_show_callback(Callback):
    """Keras callback that saves/plots both VAE visualizations each epoch.

    # Arguments
        epochs (int): total number of training epochs (stored for reference)
        name (string): model name label (stored for reference)
    """

    def __init__(self, epochs, name):
        # Initialize Keras Callback internals before adding our own state.
        super().__init__()
        self.epochs = epochs
        self.name = name

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-based; filenames use a 1-based round number.
        plot_results(models,
                     data,
                     batch_size=batch_size,
                     model_name="vae_多层感知机",
                     image1_name="二维聚类图_第{}轮.png".format(str(epoch+1)),
                     image2_name="特征数字图_第{}轮.png".format(str(epoch+1))
                     )



# number of training epochs
epochs = 50
# whether to load pre-trained weights instead of training
load_weights=True
weights='vae_mlp_mnist_weights.h5'
if load_weights:
    vae.load_weights(weights)
else:
    isc=Image_show_callback(epochs,'vae_mlp')
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            # emit the figures after every epoch (disabled by default)
#            callbacks=[isc]
           )
    vae.save_weights(weights)

# final visualization with the default output names
plot_results(models,
         data,
         batch_size=batch_size)
       

在这里插入图片描述
在这里插入图片描述

# Encode the test set; keep only z_mean (first of the encoder's three outputs)
z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=1)
z_mean.shape
(10000, 2)
z_mean[0].shape
(2,)
print(z_mean[0])
[2.3721828  0.84537846]
# Decode the first test image's latent mean back into a 784-pixel vector
x_decoded=decoder.predict(z_mean[0].reshape(1,2))
print(x_decoded.shape)
(1, 784)
plt.imshow(x_decoded.reshape(28,28), cmap='Greys_r')
<matplotlib.image.AxesImage at 0x6a2aa710>

在这里插入图片描述

plt.imshow(x_test[0].reshape(28,28), cmap='Greys_r')
<matplotlib.image.AxesImage at 0x631cc438>

在这里插入图片描述

from ipywidgets import *
def decoder_result(x,y):
    """Decode the latent point (x, y) and return it as a 28x28 image array."""
    latent_point = np.array([[x, y]])
    reconstruction = decoder.predict(latent_point)
    return reconstruction.reshape(28, 28)
# The latent vector is 2-D; matching plot_results, each coordinate is sampled
# in [-4, 4]. Dragging the sliders moves through latent space interactively,
# so the interpolated digits in between can be watched morphing.
# NOTE(review): `continuous_update` is passed to interact() itself rather than
# to the per-axis slider specs -- presumably meant per-slider; verify against
# the ipywidgets interact documentation.
@interact(x=(-4, 4, 0.1),y=(-4, 4, 0.1),continuous_update=False)
def visualize_gradient_descent(x=2.3721828,y=0.84537846):
    # Render the decoded digit for the currently selected latent point.
    plt.imshow(decoder_result(x,y), cmap='Greys_r')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,侵删。

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值