1.GAN的原理和结构
1.什么是GAN
GAN的全称是Generative adversarial network,中文翻译过来就是生成对抗网络。生成对抗网络其实是两个网络的组合:生成网络(Generator)负责生成模拟数据;判别网络(Discriminator)负责判断输入的数据是真实的还是生成的。生成网络要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确。二者关系形成对抗,因此叫对抗网络。
2.GAN的模型结构
判别器会接受两部分的数据 一部分是真实样本的数据 一部分是生成器生成的假数据
生成器接收噪声之后会生成一个猫的图片 判别器如果预测出真实的就会输出1 如果是假的就会输出0 对于判别器而言 希望将真实的样本判定为1 将生成的样本 判定为0 产生的损失会分别去优化判别器和生成器
3.GAN原理
GAN设计的关键在于损失函数的处理 对于判别模型 损失函数是容易定义的 判别器主要用来判断一张图片是真实的还是生成的 这是一个二分类问题 对于生成模型 损失函数的定义不是那么容易 希望生成器可以生成接近真实的图片 但是对于生成的图片是否像真实的 我们人类肉眼容易判断 但具体到代码中 往往是一个抽象的 难以数学公理化定义的范式 针对这个问题 就把生成模型的输出 交给判别模型的处理 让判别器来判断这是一个真实的图像还是假的图像 因为深度学习模型适合做图片的分类 这样就将生成对抗网络中的两大类模型生成器与判别器紧密联合在一起了
2.GAN的算法流程和公式详解
1.GAN的算法流程
*G是一个生成图片的网络 它接收一个随机的噪声z,通过这个噪声生成图片 叫做G(z)
*D是一个判别网络 判别一张图片是不是真实的 它的输入参数是x x代表一张图片 输出D(x)代表x为真实图片的概率 如果为1 那就代表100%是真实的图片 如果输出为0 那就代表不可能是真实的图片
在训练过程中 将随机噪声输入生成网络G,得到生成的图片; 判别器接收生成的图片和真实的图片 并尽量将两者区分开来 在这个计算过程中 能否正确区分生成的图片和真实的图片将作为判别器的损失 而能否生成近似真实的图片并使得判别器将生成的个图片判定为真将作为生成器的损失
生成器的损失是通过判别器的输出来计算的 而判别器的输出是一个概率值 我们可以通过交叉熵计算
2.GAN公式
公式中x表示真实图片
z表示输入G网络的噪声
G(z)表示G网络生成的图片
D(*)表示D网络判断图片是否真实的概率
*是有可能输入真实的图片有可能输入假的图片
x代表真实的数据(图片) 整个式子表示从真实的图片中随机挑选一张
Pdata表示真实数据的分布
z代表噪声 随即输入一个噪声 一般使用正态分布生成随机噪声 交给模型G(z)来生成一张图片 在交给D(G(z))模型调用 来产生判别器的输出
pz(z)表示生成随机噪声的分布
前半部分
对于判别器D,希望可以正确识别真实的数据 下面公式就是判别器对于真实图片的判断 data是数据集
这一部分的含义就是:判别器判别出真实数据的概率 希望这个概率越大越好 也就是说 对于服从Pdata分布的图片x,判别器应该给出预测结果D(x) = 1
后半部分
对于判别器D来说 如果输入的是生成的数据 也就是D(G(z))
判别器的目标是最小化D(G(z)) 因为判别器想要正确识别出生成的假数据 希望它被判定为0 也就是希望 log这一坨越大越好
对于生成器G来说 希望生成的数据被判别器识别为真 希望它被判定为1 也就是希望D(G(z)) = 1
也就是希望log这一坨越小越好
取对数的原因
对数函数在其定义域内是单调递增函数 数据取对数不改变数据间的相对关系 使用log后 可放大损失 便于计算和优化
3.GAN公式理解
可以看到判别器D 和生成器G 对
的优化目标是相反的 这就体现在公式中的:
3.GAN公式总结和GAN的应用
从判别器D的角度 希望最大化 V(D,G)
从生成器G的角度 希望最小化 V(D,G)
GAN的应用领域
1.图像生成
生成一些假的数据 比如海报中的人脸
2.图像增强
从分割图生成假的真实街景 方便训练无人汽车
3.风格化和艺术的图像创造
转换图像风格 修补图像
4.声音的转换
一个人的声音转为另一个的声音 去除噪音等
4.数据加载
代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
#对数据做归一化处理 -1 - 1 之间 这和Gan训练最后要用tanh函数有关 tanh函数取值范围就是-1 - 1之间
transforms = transforms.Compose([
#0-1归一化:将图像中的像素值从原范围(通常是 [0, 255] 或者具体的设备特定范围)转换到 [0, 1]的范围内,方便神经网络模型训练,加速收敛。
#通道顺序调整:图像通常以 HWC(Height, Width, Channel)的形式存在
#ToTensor() 函数会将通道(color channels)移动到最前面,变成标准的 CHW(Channel, Height, Width)格式,这是PyTorch张量的标准格式。
transforms.ToTensor(),
# mean:均值 std:方差 可以将数据转换成-1 - 1之间的数据
transforms.Normalize(0.5,0.5)
])
#加载内置数据集
train_ds = torchvision.datasets.MNIST(root='E:\Pc项目\pythonProjectlw\数据集',
train=True,
transform=transforms,
download=True)
dataloader = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
详细说一下代码的作用 1.数据处理 transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(0.5,0.5) ])
对数据做归一化处理 -1 - 1 之间 这和Gan训练最后要用tanh函数有关 tanh函数取值范围就是-1 - 1之间
transforms.ToTensor(): 0-1归一化:将图像中的像素值从原范围(通常是 [0, 255] 或者具体的设备特定范围)转换到 [0, 1] 的范围内,方便神经网络模型训练,加速收敛。 通道顺序调整:图像通常以 HWC(Height, Width, Channel)的形式存在 ToTensor() 函数会将通道(color channels)移动到最前面,变成标准的 CHW(Channel, Height, Width)格式,这是PyTorch张量的标准格式。
transforms.Normalize(0.5,0.5): mean:均值 std:方差 可以将数据转换成-1 - 1之间的数据
2.加载数据集 train_ds = torchvision.datasets.MNIST(root='E:\Pc项目\pythonProjectlw\数据集', train=True, transform=transforms, download=True):
下载并加载 MNIST 数据集的训练集到 train_ds 变量中。
3.封装数据集 dataloader = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True): 这行代码创建了一个 DataLoader 对象,它封装了 train_ds 数据集,并允许以批量的方式迭代数据 batch_size=64 指定了每个批次包含的数据样本数。这意味着在每次迭代中,DataLoader 将返回 64 个图像和对应的标签 shuffle=True 这个参数指示 DataLoader 在每个训练周期开始时打乱数据集中的样本顺序,这有助于模型学习到更泛化的特征,而不是依赖于特定的数据顺序 防止过拟合
5.生成器模型代码
生成器的输入:
生成器的输入是噪声 随机生成一个长度为100的正态分布的随机数 将它作为噪声输入到生成器模型中 对于生成器 输入就是长度为100的噪声
生成器的输出:
生成的图片(1,28,28大小的 和MNIST图片数据大小一样)
需要定义三个linear层
linear1: 100 --> 256
linear2: 256 --> 512
linear3: 512 --> 28 * 28
reshape: 28*28 --> 1*28*28
#定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main = nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512) ,
nn.ReLU(),
nn.Linear(512,28*28),
#输出层要用tanh激活 不用relu和sigmoid
nn.Tanh()
)
def forwad(self,x):
img = self.main(x)
img = img.view(-1,28,28,1)
return img
img = img.view(-1,28,28,1)
:使用view
方法将扁平的图像表示重塑为4维张量,其形状为(batch_size, 28, 28, 1)
,这对应于一批28x28像素的灰度图像(每个图像有1个颜色通道)
6.判别器模型代码
判别器的输入:
1*28*28的图片
输出:
二分类的概率值 使用sigmod函数 0-1之间的概率值
需要三个linear层和一个sigmoid激活函数层
linear1: 28*28 --> 512
linear2: 512--> 256
linear3: 256 -->1
reshape: 28*28 --> 1
GAN判别器中使用的隐藏层函数为LeakyRelu
区别如下:
ReLU:
定义:ReLU函数是一种非线性激活函数,它将所有的负值设置为0,而正值保持不变。
数学表达式:ReLU(x) = max(0, x)。
Leaky ReLU:
定义:Leaky ReLU是对ReLU的改进,它在输入值小于0时赋予一个非零的斜率,从而避免神经元死亡的问题。
数学表达式:Leaky ReLU(x) = x(当x≥0),αx(当x<0),其中α是一个小的常数斜率。
使用LeakyReLU作为激活函数是为了利用其特性来避免梯度消失、减轻神经元死亡现象以及保持计算简单性。这些优势有助于提高判别器的性能和稳定性,从而推动GAN模型的整体训练效果。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main = nn.Sequential(
nn.Linear(28*28,512),
nn.LeakyReLU(),
nn.Linear(512,256) ,
nn.LeakyReLU(),
nn.Linear(256,1),
#输出0-1概率用sigmoid
nn.Sigmoid()
)
def forward(self,x):
#输入是28*28的图片 需要展平 输入到self.main方法中
x = x.view(-1,28*28)
x = self.main(x)
return x
x = x.view(-1,28*28) 将28*28*1的三维图像转换成28*28的二维数据
7.初始化模型,优化器及损失计算函数
#初始化模型 优化器及损失计算函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#初始化生成器
gen = Generator().to(device)
#初始化判别器
dis = Discriminator().to(device)
#判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr = 0.0001)
#生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr = 0.0001)
#损失计算函数 二分类计算交叉熵损失的loss函数为BCELoss
loss_fn = torch.nn.BCELoss()
1.设备选择
device = 'cuda' if torch.cuda.is_available() else 'cpu':
这行代码用于选择PyTorch运行的设备。如果CUDA(NVIDIA的并行计算平台和编程模型)可用(即如果系统上有NVIDIA GPU并且安装了CUDA驱动和PyTorch的CUDA支持),则device
将被设置为'cuda'
,否则将被设置为'cpu'
。这允许代码在GPU上加速运行,如果GPU不可用,则回退到CPU。
2.初始化生成器
gen = Generator().to(device):
这行代码创建了一个生成器(Generator
)的实例,并将其移动到之前选择的设备上(GPU或CPU)。.to(device)
方法确保生成器的计算将在正确的设备上执行。
3.初始化判别器
dis = Discriminator().to(device):
与生成器类似,这行代码创建了一个判别器(Discriminator
)的实例,并将其移动到指定的设备上。判别器也是一个自定义的类,通常用于区分真实数据和生成器生成的数据。
4.判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001):
这行代码为判别器创建了一个Adam优化器实例。dis.parameters()
返回判别器模型中所有可训练参数的迭代器,这些参数将被优化器更新以最小化损失函数。lr=0.0001
设置了学习率,这是一个控制参数更新幅度的超参数。
5.生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001):
与判别器优化器类似,这行代码为生成器创建了一个Adam优化器实例,并使用相同的学习率。
6.损失计算函数
loss_fn = torch.nn.BCELoss():
这行代码创建了一个二分类交叉熵损失函数(BCELoss
)的实例。在生成对抗网络(GAN)中,判别器通常被训练为一个二分类器,用于区分真实数据和生成数据。因此,使用二分类交叉熵损失是合适的。这个损失函数将用于计算判别器的预测与实际标签(真实或生成)之间的差异。
8.绘图函数代码
#定义绘图函数
#model:输入的模型 比如gen或者dis
#test_input:生成器生成模型的时候需要输入一个正态分布随机数
def gen_img_plot(model, test_input):
#预测结果
#squeeze将维度为1的去掉
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
#初始化画布 画16张图片
fig = plt.figure(figsize = (4,4))
for i in range(16):
#4行4列 i从0开始 所以i+1
plt.subplot(4,4,i+1)
#预测输出结果 生成器最后使用的是tanh函数 范围是-1 - 1 无法绘图 需要转换成0-1之间的 所以( +1)/2
plt.imshow((prediction[i]+ 1)/2)
plt.axis('off')
plt.show()
#因为要绘制16张 所以要绘制16*100的张量
test_input = torch.randn(16,100,device=device)
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
model(test_input)
: 将test_input
作为输入传递给模型,得到模型的输出。.detach()
: 从计算图中分离输出,这意味着梯度不会被计算或跟踪,这在生成图像时通常是必要的,因为我们不需要进行反向传播。.cpu()
: 将输出从GPU(如果模型在GPU上运行)移动到CPU,以便后续处理。.numpy()
: 将输出转换为NumPy数组,便于后续处理。np.squeeze()
: 移除数组中所有单维度的条目,这通常用于去除批处理维度(如果输出是一个形状为[1, H, W, C]的数组,squeeze
会将其转换为[H, W, C])。
fig = plt.figure(figsize = (4,4))
- 创建一个新的图形对象
fig
,并设置其大小为4x4英寸。这将用于绘制图像网格。
- 创建一个新的图形对象
for i in range(16):
- 开始一个循环,从0迭代到15,总共迭代16次,对应于将要绘制的16张图像。
plt.subplot(4,4,i+1)
- 使用
subplot
函数在4x4的网格中创建子图,i+1
确保子图的索引从1开始(而不是从0开始),这是matplotlib的惯例。
- 使用
plt.imshow((prediction[i]+ 1)/2)
- 使用
imshow
函数显示图像。由于模型的输出通常通过tanh激活函数产生,其值范围在-1到1之间。为了将其转换为适合绘图的0到1范围,我们对输出进行了线性变换(prediction[i]+ 1)/2
。
- 使用
plt.axis('off')
- 关闭当前子图的坐标轴,使图像显示更加干净,没有坐标轴干扰。
plt.show()
- 显示整个图形对象
fig
,此时所有的子图(即16张图像)都将被渲染和显示出来。
- 显示整个图形对象
8.test_input = torch.randn(16,100,device=device):
-
torch.randn
:用于生成一个形状为指定参数的张量,其元素是从标准正态分布(均值为0,方差为1)中随机采样的 -
(16,100)
: 这是torch.randn
函数的第一个参数,指定了生成张量的形状,张量有两个维度:第一个维度是16,代表要生成的图像数量;第二个维度是100,代表输入特征的数量或模型的输入层的大小。 -
device=device
: 用于指定生成张量的设备
9.GAN的训练
具体含义已经注释
1.判别器的训练
#GAN训练
#记录每次epoch产生的loss值
D_loss = []
G_loss = []
#训练循环 60个epoch
for epoch in range(60):
#初始化损失值为0 将每一个批次产生的损失累加到d_epoch_loss中
#最后除以总的批次数 能够计算出每一个epoch的平均loss
d_epoch_loss = 0
g_epoch_loss = 0
#len(dataloader):返回批次数 len(dataset):返回样本数
count = len(dataloader)
#对dataloader进行迭代会返回一个批次的图片 根据这个图片的数量创建同样数量的noise 作为genertator的输入
for step,(img, _) in enumerate(dataloader):
img = img.to(device)
#获取批次的大小
size = img.size(0)
random_noise = torch.rand(size, 100, device=device)
#判别器训练
#判别器训练过程 有两部分损失 一部分是判断为真的损失 一部分是判断为假的损失
d_optim.zero_grad()
#1.真实的损失
#判别器输入真实的图片 real_output就是对真实图片的预测结果
real_output = dis(img)
#得到判别器在真实图像上的损失
#希望判别器对真实数据的输出real_output的值为1
#这个函数是用来衡量真实输出(real_output)与预期值torch.ones_like(real_output)
# (在这里是与 real_output 形状相同的全1张量)之间的差异
d_real_loss = loss_fn(real_output,
torch.ones_like(real_output))
#反向传播求梯度
d_real_loss.backward()
#2.假的损失
#先将噪声传入生成器生成图片
gen_img = gen(random_noise)
#判别器输入生成的图片 fake_output就是对生成图片的预测结果
#对于生成器产生的损失 优化的目标是判别器 生成器的参数是不需要进行优化的
#当训练判别器时,我们希望仅根据判别器自身的参数和输入(真实图像或生成图像)来计算梯度并更新参数
#如果生成图像的梯度传递到生成器,那么生成器的参数也会在判别器的训练过程中被更新,这不是我们想要的结果
# 因此,我们使用.detach()方法来截断梯度,确保它们不会传播回生成器 .detach()会得到一个没有批次的tensor
fake_output = dis(gen_img.detach())
d_fake_loss = loss_fn(fake_output,
torch.zeros_like(fake_output))
#反向传播求梯度
d_fake_loss.backward()
#损失求和
d_loss = d_real_loss + d_fake_loss
#优化并更新参数
d_optim.step()
2.生成器的训练并记录损失打印
#生成器的损失优化 只包含一部分
#将生成器所有的梯度归零
g_optim.zero_grad()
#将生成器的图片传入到判别器中 没有做梯度的截断
fake_output = dis(gen_img)
#生成器的损失 .ones 因为是对生成器进行优化 生成器想判别为1
g_loss = loss_fn(fake_output,
torch.ones_like(fake_output))
g_loss.backward()
g_optim.step()
#计算在一个epoch当中总共的g_loss和d_loss
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
#计算平均loss值
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
#传入到列表当中
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
# 打印信息
print('Epoch:', epoch, 'D Loss:', d_epoch_loss, 'G Loss:', g_epoch_loss)
# 每隔几个epoch绘图(可选) 每隔20个打印一下
if (epoch + 1) % 20 == 0:
gen_img_plot(gen, test_input)
10.完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
#对数据做归一化处理 -1 - 1 之间 这和Gan训练最后要用tanh函数有关 tanh函数取值范围就是-1 - 1之间
transforms = transforms.Compose([
#0-1归一化:将图像中的像素值从原范围(通常是 [0, 255] 或者具体的设备特定范围)转换到 [0, 1] 的范围内,方便神经网络模型训练,加速收敛。
#通道顺序调整:图像通常以 HWC(Height, Width, Channel)的形式存在
#ToTensor() 函数会将通道(color channels)移动到最前面,变成标准的 CHW(Channel, Height, Width)格式,这是PyTorch张量的标准格式。
transforms.ToTensor(),
# mean:均值 std:方差 可以将数据转换成-1 - 1之间的数据
transforms.Normalize(0.5,0.5)
])
#加载内置数据集
train_ds = torchvision.datasets.MNIST(root='E:\Pc项目\pythonProjectlw\数据集',
train=True,
transform=transforms,
download=True)
dataloader = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
#定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main = nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512) ,
nn.ReLU(),
nn.Linear(512,28*28),
#输出层要用tanh激活 不用relu和sigmoid
nn.Tanh()
)
def forward(self,x):
img = self.main(x)
img = img.view(-1,28,28)
return img
#定义判别器 就是传统的分类模型
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main = nn.Sequential(
nn.Linear(28*28,512),
nn.LeakyReLU(),
nn.Linear(512,256) ,
nn.LeakyReLU(),
nn.Linear(256,1),
#输出0-1概率用sigmoid
nn.Sigmoid()
)
def forward(self,x):
#输入是28*28的图片 需要展平 输入到self.main方法中
x = x.view(-1,28*28)
x = self.main(x)
return x
#初始化模型 优化器及损失计算函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#初始化生成器
gen = Generator().to(device)
#初始化判别器
dis = Discriminator().to(device)
#判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr = 0.0001)
#生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr = 0.0001)
#损失计算函数 二分类计算交叉熵损失的loss函数为BCELoss
loss_fn = torch.nn.BCELoss()
#定义绘图函数
#model:输入的模型 比如gen或者dis
#test_input:生成器生成模型的时候需要输入一个正态分布随机数
def gen_img_plot(model,test_input):
#预测结果
#squeeze将维度为1的去掉
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
#初始化画布 画16张图片
fig = plt.figure(figsize = (4,4))
for i in range(16):
#4行4列 i从0开始 所以i+1
plt.subplot(4,4,i+1)
#预测输出结果 生成器最后使用的是tanh函数 范围是-1 - 1 无法绘图 需要转换成0-1之间的 所以( +1)/2
plt.imshow((prediction[i]+ 1)/2)
plt.axis('off')
plt.show()
#因为要绘制16张 所以要绘制16*100的张量
test_input = torch.randn(16,100,device=device)
#GAN训练
#记录每次epoch产生的loss值
D_loss = []
G_loss = []
#训练循环 20个epoch
for epoch in range(30):
#初始化损失值为0 将每一个批次产生的损失累加到d_epoch_loss中
#最后除以总的批次数 能够计算出每一个epoch的平均loss
d_epoch_loss = 0
g_epoch_loss = 0
#len(dataloader):返回批次数 len(dataset):返回样本数
count = len(dataloader)
#对dataloader进行迭代会返回一个批次的图片 根据这个图片的数量创建同样数量的noise 作为genertator的输入
for step,(img, _) in enumerate(dataloader):
img = img.to(device)
#获取批次的大小
size = img.size(0)
random_noise = torch.rand(size, 100, device=device)
#判别器训练
#判别器训练过程 有两部分损失 一部分是判断为真的损失 一部分是判断为假的损失
d_optim.zero_grad()
#1.真实的损失
#判别器输入真实的图片 real_output就是对真实图片的预测结果
real_output = dis(img)
#得到判别器在真实图像上的损失
#希望判别器对真实数据的输出real_output的值为1
#这个函数是用来衡量真实输出(real_output)与预期值torch.ones_like(real_output)
# (在这里是与 real_output 形状相同的全1张量)之间的差异
d_real_loss = loss_fn(real_output,
torch.ones_like(real_output))
#反向传播求梯度
d_real_loss.backward()
#2.假的损失
#先将噪声传入生成器生成图片
gen_img = gen(random_noise)
#判别器输入生成的图片 fake_output就是对生成图片的预测结果
#对于生成器产生的损失 优化的目标是判别器 生成器的参数是不需要进行优化的
#当训练判别器时,我们希望仅根据判别器自身的参数和输入(真实图像或生成图像)来计算梯度并更新参数
#如果生成图像的梯度传递到生成器,那么生成器的参数也会在判别器的训练过程中被更新,这不是我们想要的结果
# 因此,我们使用.detach()方法来截断梯度,确保它们不会传播回生成器 .detach()会得到一个没有批次的tensor
fake_output = dis(gen_img.detach())
d_fake_loss = loss_fn(fake_output,
torch.zeros_like(fake_output))
#反向传播求梯度
d_fake_loss.backward()
#损失求和
d_loss = d_real_loss + d_fake_loss
#优化并更新参数
d_optim.step()
#生成器的损失优化 只包含一部分
#将生成器所有的梯度归零
g_optim.zero_grad()
#将生成器的图片传入到判别器中 没有做梯度的截断
fake_output = dis(gen_img)
#生成器的损失 .ones 因为是对生成器进行优化 生成器想判别为1
g_loss = loss_fn(fake_output,
torch.ones_like(fake_output))
g_loss.backward()
g_optim.step()
#计算在一个epoch当中总共的g_loss和d_loss
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
#计算平均loss值
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
#传入到列表当中
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
# 打印信息
print('Epoch:', epoch, 'D Loss:', d_epoch_loss, 'G Loss:', g_epoch_loss)
# 每隔几个epoch绘图(可选)
if (epoch + 1) % 3 == 0: # some_interval是您选择的间隔
gen_img_plot(gen, test_input)