keras实战-入门之VAE变分自编码器
VAE变分自编码器
原理这里就不展开了,可以参考这篇文章。
代码参考了 Keras 官方示例,稍作修改并添加了注释;最后加入了一个可交互的可视化组件,方便调整参数查看结果,可以看到插值生成的图片。
from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
from keras.callbacks import Callback
import numpy as np
import matplotlib.pyplot as plt
import os
Using TensorFlow backend.
def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist_多层感知机",
                 image1_name="二维聚类图.png",
                 image2_name="特征数字图.png"
                 ):
    """Plot the 2-D latent space and a grid of decoded MNIST digits.

    Two figures are produced; each is saved under ``model_name`` and shown:

    1. A scatter plot of the encoder's ``z_mean`` for every test sample,
       coloured by its label.
    2. A 30 x 30 mosaic of 28 x 28 digits decoded from an evenly spaced grid
       of latent points covering [-4, 4] x [-4, 4].

    # Arguments
        models (tuple): (encoder, decoder) models. ``encoder.predict`` must
            return three outputs, the first being ``z_mean``.
        data (tuple): (x_test, y_test) test data and labels.
        batch_size (int): prediction batch size.
        model_name (string): directory the figures are saved in (created if
            it does not exist).
        image1_name (string): file name of the latent scatter plot.
        image2_name (string): file name of the decoded-digit mosaic.
    """
    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    # Figure 1: 2-D latent means of the test set, coloured by label.
    filename = os.path.join(model_name, image1_name)
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()

    # Figure 2: n x n = 900 digits decoded from a regular latent grid.
    filename = os.path.join(model_name, image2_name)
    n = 30
    digit_size = 28
    # n evenly spaced sample points per axis; y is reversed so that larger
    # z[1] values appear at the top of the image.
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]
    # All n*n latent points in row-major order (row i <-> grid_y[i],
    # column j <-> grid_x[j]), decoded in one batched call instead of
    # n*n separate predict() calls.
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_grid = np.column_stack([xx.ravel(), yy.ravel()])
    x_decoded = decoder.predict(z_grid, batch_size=batch_size)
    # (n*n, 784) -> (n, n, 28, 28) -> (n, 28, n, 28) -> (n*28, n*28):
    # tile (i, j) lands at rows i*28..(i+1)*28, cols j*28..(j+1)*28.
    figure = (x_decoded.reshape(n, n, digit_size, digit_size)
              .transpose(0, 2, 1, 3)
              .reshape(digit_size * n, digit_size * n))

    plt.figure(figsize=(10, 10))
    # Put a tick at the centre of every tile, labelled with its latent value.
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()
# Load MNIST as ((train images, train labels), (test images, test labels)).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Images are square; take the side length from the second array axis.
image_size = x_train.shape[1]
# Show the image side length and the raw training-array shape, one per line.
print(image_size, x_train.shape, sep="\n")
28
(60000, 28, 28)
# Each image_size x image_size (28 x 28) image becomes a flat vector of
# original_dim input features.
original_dim = image_size ** 2
# Flatten both splits to (n_samples, original_dim) and rescale the pixel
# values from [0, 255] to float32 values in [0, 1].
x_train, x_test = [
    images.reshape(-1, original_dim).astype('float32') / 255
    for images in (x_train, x_test)
]
print(x_train.shape, x_test.shape, sep="\n")
(60000, 784)
(10000, 784)
# Network hyper-parameters.
input_shape = (original_dim, )
# Number of units in the hidden (intermediate) layer.
intermediate_dim = 512
# Mini-batch size used for training and prediction.
batch_size = 128
# Dimensionality of the latent vector z (2 so it can be plotted directly).
latent_dim = 2
# Encoder: flattened image -> hidden ReLU layer -> (z_mean, z_log_var).
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
# The encoder predicts the log of the variance rather than the variance
# itself, which keeps the arithmetic simple: predicting the variance directly
# would also require taking a square root to get the standard deviation.
z_log_var = Dense(latent_dim, name='z_log_var')(x)
WARNING:tensorflow:From F:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
# Reparameterization trick: sampling z directly is not differentiable, so
# instead of sampling from Q(z|X) we sample epsilon ~ N(0, I) and compute
#     z = z_mean + sqrt(var) * epsilon
def sampling(args):
    """Sample z from N(z_mean, exp(z_log_var)) via the reparameterization trick.

    # Arguments
        args (list of tensors): ``[z_mean, z_log_var]``, the mean and the
            log-variance of Q(z|X) produced by the encoder.
    # Returns
        z (tensor): sampled latent vector with the shape (batch, dim) of
            ``z_mean``.
    """
    mu, log_var = args
    # Prints the symbolic shape tensor while the layer graph is being built.
    print(K.shape(mu))
    n_rows = K.shape(mu)[0]          # dynamic batch size
    n_dims = K.int_shape(mu)[1]      # static latent dimension
    print(K.shape(mu), n_rows, n_dims)
    # Standard-normal noise (random_normal defaults: mean 0, std 1).
    noise = K.random_normal(shape=(n_rows, n_dims))
    # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation.
    sigma = K.exp(0.5 * log_var)
    return mu + sigma * noise
# Sampling trick: draw e from a standard Gaussian and combine it with the
# encoder's z_mean and z_log_var to produce a latent_dim (here 2-D) vector z.
# A Lambda layer wraps the custom `sampling` function as a Keras layer.
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
Tensor("z/Shape:0", shape=(2,), dtype=int32)
Tensor("z/Shape_2:0", shape=(2,), dtype=int32) Tensor("z/strided_slice:0", shape=(), dtype=int32) 2
# Encoder model: flattened image -> [z_mean, z_log_var, sampled z].
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 512) 401920 encoder_input[0][0]
__________________________________________________________________________________________________
z_mean (Dense) (None, 2) 1026 dense_1[0][0]
__________________________________________________________________________________________________
z_log_var (Dense) (None, 2) 1026 dense_1[0][0]
__________________________________________________________________________________________________
z (Lambda) (None, 2) 0 z_mean[0][0]
z_log_var[0][0]
==================================================================================================
Total params: 403,972
Trainable params: 403,972
Non-trainable params: 0
__________________________________________________________________________________________________
# Decoder: latent vector -> hidden ReLU layer -> flattened image.
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
# Sigmoid keeps every output pixel in (0, 1), the same range the inputs were
# scaled to.
outputs = Dense(original_dim, activation='sigmoid')(x)
# Decoder model.
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
z_sampling (InputLayer) (None, 2) 0
_________________________________________________________________
dense_2 (Dense) (None, 512) 1536
_________________________________________________________________
dense_3 (Dense) (None, 784) 402192
=================================================================
Total params: 403,728
Trainable params: 403,728
Non-trainable params: 0
_________________________________________________________________
# End-to-end VAE: the decoder reconstructs the image from the sampled z
# (output index 2 of the encoder).
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
Tensor("encoder/z/Shape:0", shape=(2,), dtype=int32)
Tensor("encoder/z/Shape_2:0", shape=(2,), dtype=int32) Tensor("encoder/z/strided_slice:0", shape=(), dtype=int32) 2
models = (encoder, decoder)
data = (x_test, y_test)
# Whether to use mean squared error (True) or binary cross-entropy (False)
# for the reconstruction term.
use_mse=True
# Total loss = reconstruction error between input and output
#            + KL divergence between the encoder's Gaussian
#              N(z_mean, exp(z_log_var)) and the standard normal prior N(0, I).
if use_mse:
    reconstruction_loss = mse(inputs, outputs)
else:
    reconstruction_loss = binary_crossentropy(inputs,
                                              outputs)
# Keras' mse / binary_crossentropy average over the last axis (the pixels);
# multiplying by original_dim turns that into a per-image sum over pixels.
reconstruction_loss *= original_dim
# Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions:
#     0.5 * sum(exp(log_var) - (1 + log_var) + mu^2)
# (a derivation is in the author's "Hung-yi Lee GAN notes: VAE" post).
kl_loss = K.exp(z_log_var)-(1 + z_log_var )+K.square(z_mean)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= 0.5
# Average the per-sample loss over the batch.
vae_loss = K.mean(reconstruction_loss + kl_loss)
# The loss is attached to the model with add_loss, so compile() receives no
# `loss` argument (which is why Keras warns that "decoder" is missing from
# the loss dictionary).
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) (None, 784) 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 403972
_________________________________________________________________
decoder (Model) (None, 784) 403728
=================================================================
Total params: 807,700
Trainable params: 807,700
Non-trainable params: 0
_________________________________________________________________
F:\Anaconda3\lib\site-packages\ipykernel_launcher.py:23: UserWarning: Output "decoder" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "decoder" during training.
# Callback run after every training epoch to plot the latent-space figures.
class Image_show_callback(Callback):
    """Keras callback that calls ``plot_results`` at the end of each epoch.

    Figures are written to the ``vae_多层感知机`` directory, with the 1-based
    epoch number embedded in the file names. It uses the module-level
    ``models``, ``data`` and ``batch_size`` globals.

    # Arguments
        epochs (int): total number of epochs (only stored on the instance).
        name (string): a model name (only stored on the instance; it does
            not affect the output directory).
    """

    def __init__(self, epochs, name):
        # Initialise the Keras base class first so whatever attributes it
        # sets up exist before Keras attaches the model to this callback.
        super().__init__()
        self.epochs = epochs
        self.name = name

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None instead of a shared mutable `{}`; it is not
        # used here. `epoch` is 0-based, so add 1 for the "第N轮" file names.
        epoch_no = epoch + 1
        plot_results(models,
                     data,
                     batch_size=batch_size,
                     model_name="vae_多层感知机",
                     image1_name="二维聚类图_第{}轮.png".format(epoch_no),
                     image2_name="特征数字图_第{}轮.png".format(epoch_no)
                     )
# Number of training epochs.
epochs = 50
# Whether to load previously saved weights instead of training.
load_weights=True
weights='vae_mlp_mnist_weights.h5'
# Set to True to plot the latent-space figures after every epoch.
show_epoch_images = False
# Only load when the weights file actually exists; otherwise (e.g. on a first
# run) fall back to training instead of crashing.
if load_weights and os.path.exists(weights):
    vae.load_weights(weights)
else:
    if load_weights:
        print("Weights file {} not found, training from scratch.".format(weights))
    isc = Image_show_callback(epochs, 'vae_mlp')
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            # Optionally plot images after every epoch.
            callbacks=[isc] if show_epoch_images else None
            )
    vae.save_weights(weights)
plot_results(models,
             data,
             batch_size=batch_size)
# Encode the whole test set and keep only the latent means. Use the regular
# batch size: batch_size=1 would run one predict step per test sample.
# NOTE: this rebinds the global `z_mean` (previously the Keras tensor) to a
# NumPy array of per-sample latent means.
z_mean, _, _ = encoder.predict(x_test,
                               batch_size=batch_size)
z_mean.shape
(10000, 2)
# Shape of a single latent code (a 1-D vector of length latent_dim).
np.shape(z_mean[0])
(2,)
# Latent mean of the first test image.
print(z_mean[0, :])
[2.3721828 0.84537846]
# Decode the first test sample's latent mean back to a flattened image.
# Slicing with [:1] keeps the batch axis, giving shape (1, latent_dim) for any
# latent size (the old reshape(1, 2) only worked for latent_dim == 2).
x_decoded = decoder.predict(z_mean[:1])
print(x_decoded.shape)
(1, 784)
# Show the reconstruction of the first test image as image_size x image_size.
plt.imshow(x_decoded.reshape(image_size, image_size), cmap='Greys_r')
<matplotlib.image.AxesImage at 0x6a2aa710>
# Show the original first test image for comparison.
plt.imshow(x_test[0].reshape(image_size, image_size), cmap='Greys_r')
<matplotlib.image.AxesImage at 0x631cc438>
from ipywidgets import interact, FloatSlider


def decoder_result(x, y):
    """Decode the 2-D latent point (x, y) into an image_size x image_size image."""
    x_decoded = decoder.predict(np.array([[x, y]]))
    return x_decoded.reshape(image_size, image_size)


# Interactive explorer over the same [-4, 4] latent range that plot_results
# samples: moving the sliders shows how the decoded digit morphs between
# classes (the interpolated images in between).
# `continuous_update=False` must be set on the slider widgets themselves:
# passed to `interact` as a keyword it is treated as a widget abbreviation for
# a non-existent parameter and silently ignored.
@interact(
    x=FloatSlider(min=-4, max=4, step=0.1, value=2.3721828,
                  continuous_update=False),
    y=FloatSlider(min=-4, max=4, step=0.1, value=0.84537846,
                  continuous_update=False),
)
def visualize_gradient_descent(x=2.3721828, y=0.84537846):
    """Display the digit decoded from the latent point (x, y).

    Despite its name, this does not visualise gradient descent; the defaults
    are the latent mean of the first test image printed earlier.
    """
    plt.imshow(decoder_result(x, y), cmap='Greys_r')
好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,侵删。