1. Description
The strength of the Glow model is that both the model's forward computation and its reverse computation can be defined explicitly, so the same network maps data to latents and latents back to data. The paper is excellent; these are hasty notes, the code is incomplete, and the notes will be updated later. The code does not run yet and still needs polishing, but the ideas in the Glow paper are outstanding. The way it fits the model by splitting things apart, the channel split in the coupling layers and the LU factorization of the 1x1 convolution weight, is very much in the spirit of MIT professor Gilbert Strang's linear algebra. These notes should be restructured later with a cleaner organization.
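The bookkeeping every invertible layer has to supply is the log-determinant of its Jacobian, because the change-of-variables formula gives log p(x) = log p(z) + log|det(dz/dx)|. A minimal stand-alone sketch of that idea (not part of the Glow code below; the element-wise affine map here is only a stand-in for something like ActNorm):

import math
import torch

# a stand-in invertible layer: z = s * x + t (element-wise)
s = torch.tensor(2.0)
t = torch.tensor(0.5)

def forward(x):
    z = s * x + t
    # each of the D elements of a sample contributes log|s| to log|det(dz/dx)|
    logdet = x.shape[1] * torch.log(torch.abs(s))
    return z, logdet

def reverse(z):
    return (z - t) / s

x = torch.randn(4, 3)
z, logdet = forward(x)
# change of variables: log p(x) = log p(z) + log|det(dz/dx)|, with p(z) a standard normal
log_pz = (-0.5 * z ** 2 - 0.5 * math.log(2 * math.pi)).sum(dim=1)
log_px = log_pz + logdet
print(torch.allclose(reverse(z), x))  # True: the map is exactly invertible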
2. PyTorch (not runnable yet, to be cleaned up later)
import torch
import torch.nn as nn
from torch.nn import functional as F
from math import log, pi, exp
import numpy as np
from scipy import linalg as la
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logabs = lambda x: torch.log(torch.abs(x))
def gaussian_log_p(x, mean, log_sd):
return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
def gaussian_sample(eps, mean, log_sd):
"""
// function: 从高斯分布中采用,其实就是参数重整化
:param eps:
:param mean:
:param log_sd:
:return:
"""
return mean + torch.exp(log_sd) * eps
class ActNorm(nn.Module):
def __init__(self, in_channel, logdet=True):
super(ActNorm, self).__init__()
self.in_channel = in_channel
self.logdet = logdet
        # Statistics are per-channel, so loc and scale each hold just num_channels values
        self.loc = nn.Parameter(torch.zeros(1, self.in_channel, 1, 1))  # data layout is BS*CH*H*W
        self.scale = nn.Parameter(torch.ones(1, self.in_channel, 1, 1))
        self.initialized = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)
    def initialize(self, input):
        with torch.no_grad():
            # per-channel mean/std over the batch and spatial dimensions
            mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
            std = torch.std(input, dim=(0, 2, 3), keepdim=True)
            # Data-dependent initialization: the first batch leaves ActNorm with zero mean and
            # unit variance per channel (mean-variance normalization), which stabilizes training.
            # copy_ assigns the value in place
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))
    def forward(self, input):
        batch_size, _, height, width = input.shape
        # .item() returns a Python scalar; only valid for 0-dim tensors
        if self.initialized.item() == 0:
            self.initialize(input)
            # fill_ writes the flag in place
            self.initialized.fill_(1)
        log_abs = logabs(self.scale)
        # change in log-likelihood contributed by this layer; a scalar
        logdet = height * width * torch.sum(log_abs)
        # broadcast to one entry per sample
        logdet = logdet.repeat(batch_size)
        # forward computation
        if self.logdet:
            return self.scale * (input + self.loc), logdet
        else:
            return self.scale * (input + self.loc)
    def reverse(self, output):
        return output / self.scale - self.loc
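# Quick sanity check (assumed usage, not from the original post): ActNorm's reverse
# should undo forward, and logdet should have one entry per sample.
if __name__ == "__main__":
    _an = ActNorm(4).to(device)
    _x = torch.randn(8, 4, 16, 16, device=device)
    _y, _ld = _an(_x)
    print(_ld.shape)  # torch.Size([8])
    print((_an.reverse(_y) - _x).abs().max())  # ~0 up to float error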
class InvConv2d(nn.Module):
def __init__(self, in_channel):
super(InvConv2d, self).__init__()
self.in_channel = in_channel
        weight = torch.randn(self.in_channel, self.in_channel)
        # random orthogonal initialization via QR decomposition
        q, _ = torch.linalg.qr(weight)
        weight = q.unsqueeze(2).unsqueeze(3)
self.weight = nn.Parameter(weight)
def forward(self, input):
batch_size, _, height, width = input.shape
out = F.conv2d(input, self.weight)
logdet = (
height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
)
return out, logdet
def reverse(self, output):
return F.conv2d(
output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)
)
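# Quick sanity check (assumed usage): the invertible 1x1 convolution is exactly invertible
# and its logdet is H * W * log|det W|.
if __name__ == "__main__":
    _conv = InvConv2d(4).to(device)
    _x = torch.randn(2, 4, 8, 8, device=device)
    _y, _ld = _conv(_x)
    print(_ld)  # scalar: 8 * 8 * log|det W|
    print((_conv.reverse(_y) - _x).abs().max())  # ~0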
class InvConv2dLU(nn.Module):
def __init__(self, in_channel):
super(InvConv2dLU, self).__init__()
        weight = np.random.randn(in_channel, in_channel)
        # QR decomposition: Q is orthogonal, R is upper triangular (any matrix admits a QR decomposition)
        q, _ = la.qr(weight)
        # Q is orthogonal, so det(Q) = +/-1 != 0 and an LU decomposition exists:
        # A = P L U, with P a permutation matrix, L lower triangular, U upper triangular.
        # With scipy's convention, L has a unit diagonal.
        w_p, w_l, w_u = la.lu(q.astype(np.float32))
        # main diagonal of U as a 1-D vector
        w_s = np.diag(w_u)
        # keep only the strictly upper-triangular part of U (diagonal zeroed out)
        w_u = np.triu(w_u, 1)
        # mask of ones on the strictly upper triangle
        u_mask = np.triu(np.ones_like(w_u), 1)
        # mask of ones on the strictly lower triangle
        l_mask = u_mask.T
        w_p = torch.from_numpy(w_p).to(device)
        w_l = torch.from_numpy(w_l).to(device)
        w_s = torch.from_numpy(w_s).to(device)
        w_u = torch.from_numpy(w_u).to(device)
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", torch.from_numpy(u_mask).to(device))
        self.register_buffer("l_mask", torch.from_numpy(l_mask).to(device))
        # sign of each diagonal entry of U (+1 or -1)
        self.register_buffer("s_sign", torch.sign(w_s))
        # identity matrix, used to restore the unit diagonal of L
        self.register_buffer("l_eye", torch.eye(len(w_s), device=device))
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)
    def forward(self, input):
        batch_size, _, height, width = input.shape
        # weight has shape [in_channel, in_channel, 1, 1]
        weight = self.calc_weight()
        out = F.conv2d(input, weight)
        # logdet is a scalar; broadcast to one entry per sample
        logdet = height * width * torch.sum(self.w_s)
        logdet = logdet.repeat(batch_size)
        return out, logdet
    def calc_weight(self):
        # W = P (L * l_mask + I) (U * u_mask + diag(s_sign * exp(log|s|)))
        # The masks are re-applied on every call because w_l and w_u drift away from the
        # strictly triangular pattern as training updates them.
        weight = (
            self.w_p
            @ (self.w_l * self.l_mask + self.l_eye)
            @ (self.w_u * self.u_mask + torch.diag(self.s_sign * torch.exp(self.w_s)))
        )
        # reshape to a 4-D 1x1 convolution kernel
        return weight.unsqueeze(2).unsqueeze(3)
    def reverse(self, output):
        weight = self.calc_weight()
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
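# Quick sanity check (assumed usage): with the LU parameterization, calc_weight rebuilds an
# invertible kernel, forward/reverse are mutual inverses, and logdet = H*W*sum(log|s|).
if __name__ == "__main__":
    _lu = InvConv2dLU(4).to(device)
    _x = torch.randn(2, 4, 8, 8, device=device)
    _y, _ld = _lu(_x)
    print(_ld.shape)  # torch.Size([2])
    print((_lu.reverse(_y) - _x).abs().max())  # ~0 up to float32 error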
class ZeroConv2d(nn.Module):
def __init__(self, in_channel, out_channel, padding=1):
super(ZeroConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)
        # zero-initialize the convolution weights and bias
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()
        # shape [1, OC, 1, 1], broadcast over [BS, OC, H, W]:
        # a learnable per-channel output scale, also zero-initialized
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
    def forward(self, input):
        # pad with ones rather than zeros (as in the reference Glow implementation)
        out = F.pad(input, [1, 1, 1, 1], value=1)
        out = self.conv(out)
        # per-channel rescaling; exp(0 * 3) = 1 at initialization, so the layer starts at zero output
        out = out * torch.exp(self.scale * 3)
        return out
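# Quick sanity check (assumed usage): ZeroConv2d is zero-initialized, so right after
# construction it maps any input to all zeros; this is what lets the coupling layer
# start out close to an identity transform.
if __name__ == "__main__":
    _zc = ZeroConv2d(8, 16).to(device)
    _x = torch.randn(2, 8, 5, 5, device=device)
    print(_zc(_x).abs().max())  # 0 at initialization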
class AffineCoupling(nn.Module):
def __init__(self, in_channel, filter_size=512, affine=True):
super(AffineCoupling, self).__init__()
self.affine = affine
        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            # The last layer is zero-initialized, so the coupling starts close to an identity map.
            # With affine=True it outputs two chunks, log_s and t; otherwise only t.
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),
        )
        self.net[0].weight.data.normal_(0, 0.05)
        self.net[0].bias.data.zero_()
        self.net[2].weight.data.normal_(0, 0.05)
        self.net[2].bias.data.zero_()
    def forward(self, input):
        # split along the channel dimension into halves a and b
        in_a, in_b = input.chunk(2, 1)
        if self.affine:
            # the network maps in_a to log_s and t
            log_s, t = self.net(in_a).chunk(2, 1)
            # s = torch.exp(log_s)
            # sigmoid (shifted by +2) trains more stably than exp
            s = torch.sigmoid(log_s + 2)
            # out_a = in_a (passed through unchanged)
            out_b = (in_b + t) * s
            # s has shape [BS, ch//2, h, w]; since it depends on the input,
            # logdet has shape [BS] instead of being a single scalar
            logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
        else:
            net_out = self.net(in_a)
            out_b = in_b + net_out
            logdet = None
        # in_a passes through unchanged; concatenate the two halves again
        return torch.cat([in_a, out_b], 1), logdet
    def reverse(self, output):
        out_a, out_b = output.chunk(2, 1)
        if self.affine:
            # invert the affine transform: out_a == in_a, so the same log_s and t are recovered
            log_s, t = self.net(out_a).chunk(2, 1)
            # s = torch.exp(log_s)
            s = torch.sigmoid(log_s + 2)
            # in_b = out_b / s - t  (inverse of out_b = (in_b + t) * s)
            in_b = out_b / s - t
        else:
            net_out = self.net(out_a)
            in_b = out_b - net_out
        # concatenate the two halves again
        return torch.cat([out_a, in_b], 1)
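# Quick sanity check (assumed usage): the affine coupling only transforms half of the
# channels, so reverse undoes forward exactly and logdet is per-sample.
if __name__ == "__main__":
    _cp = AffineCoupling(8).to(device)
    _x = torch.randn(2, 8, 16, 16, device=device)
    _y, _ld = _cp(_x)
    print(_ld.shape)  # torch.Size([2])
    print((_cp.reverse(_y) - _x).abs().max())  # ~0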
class Flow(nn.Module):
def __init__(self, in_channel, affine=True, conv_lu=True):
super(Flow, self).__init__()
self.actnorm = ActNorm(in_channel)
if conv_lu:
self.invconv = InvConv2dLU(in_channel)
else:
self.invconv = InvConv2d(in_channel)
self.coupling = AffineCoupling(in_channel, affine=affine)
def forward(self, input):
out, logdet = self.actnorm(input)
# print(f"get device of logdet:",logdet.get_device())
out, det1 = self.invconv(out)
out, det2 = self.coupling(out)
# print(f"get device of det1:",det1.get_device())
# print(f"get device of det2:",det2.get_device())
logdet = logdet + det1
if det2 is not None:
logdet = logdet + det2
return out, logdet
def reverse(self, output):
input = self.coupling.reverse(output)
input = self.invconv.reverse(input)
input = self.actnorm.reverse(input)
return input
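# Quick sanity check (assumed usage): one flow step = ActNorm -> invertible 1x1 conv ->
# affine coupling; the whole step is invertible and accumulates the three logdet terms.
if __name__ == "__main__":
    _fl = Flow(8).to(device)
    _x = torch.randn(2, 8, 16, 16, device=device)
    _y, _ld = _fl(_x)
    print(_ld.shape)  # torch.Size([2])
    print((_fl.reverse(_y) - _x).abs().max())  # small (float32 matrix-inverse error)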
class Block(nn.Module):
"""
每个Block 包含多个Flow
"""
def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
super(Block, self).__init__()
squeeze_dim = in_channel * 4
self.flows = nn.ModuleList()
for i in range(n_flow):
self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
self.split = split
if split:
self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
else:
self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)
self.h_zero = nn.Parameter(torch.zeros(1, in_channel * 4, 8, 8), requires_grad=True)
self.label_embedding = nn.Embedding(40, 32)
self.proj_layer = nn.Linear(32, in_channel * 4)
def forward(self, input, label):
b_size, n_channel, height, width = input.shape
        # squeeze: trade spatial resolution for channels (x4 channels, /2 in each spatial dim)
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)
        logdet = 0
        for flow in self.flows:
            # each flow preserves the tensor shape
            out, det = flow(out)
            logdet = logdet + det
if self.split:
out, z_new = out.chunk(2, 1)
mean, log_sd = self.prior(out).chunk(2, 1)
log_p = gaussian_log_p(z_new, mean, log_sd)
log_p = log_p.view(b_size, -1).sum(1)
else:
condition = self.label_embedding(label).sum(1) # [BS,dim]
condition = F.softplus(self.proj_layer(condition))
condition = condition.unsqueeze(-1).unsqueeze(-1)
mean, log_sd = self.prior(self.h_zero + condition).chunk(2, 1)
log_p = gaussian_log_p(out, mean, log_sd)
log_p = log_p.view(b_size, -1).sum(1)
z_new = out
return out, logdet, log_p, z_new
def reverse(self, output, label, eps=None, reconstruct=False):
input = output
if reconstruct:
if self.split:
input = torch.cat([output, eps], 1)
else:
input = eps
else:
if self.split:
mean, log_sd = self.prior(input).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = torch.cat([output, z], 1)
            else:
                # mirror the conditioning used in forward: sum the label embeddings to [BS, dim]
                condition = self.label_embedding(label).sum(1)
                condition = F.softplus(self.proj_layer(condition))
                condition = condition.unsqueeze(-1).unsqueeze(-1)
                mean, log_sd = self.prior(self.h_zero + condition).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z
for flow in self.flows[::-1]:
input = flow.reverse(input)
b_size, n_channel, height, width = input.shape
        # unsqueeze: restore [bs, 4c, h/2, w/2] back to [bs, c, h, w]
unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
unsqueezed = unsqueezed.contiguous().view(
b_size, n_channel // 4, height * 2, width * 2
)
return unsqueezed
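# Quick sanity check (assumed usage): for a split Block, reverse(out, label, eps=z_new,
# reconstruct=True) should rebuild the original input (the label is only consumed by the
# non-split, conditional last block).
if __name__ == "__main__":
    _blk = Block(6, n_flow=2, split=True).to(device)
    _x = torch.rand(2, 6, 16, 16, device=device)
    _lbl = torch.randint(0, 40, (2, 3), device=device)
    _out, _ld, _logp, _z = _blk(_x, _lbl)
    print(_out.shape, _z.shape)  # [2, 12, 8, 8] each
    _rec = _blk.reverse(_out, _lbl, eps=_z, reconstruct=True)
    print((_rec - _x).abs().max())  # ~0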
class Glow(nn.Module):
def __init__(
self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
super(Glow, self).__init__()
self.blocks = nn.ModuleList()
n_channel = in_channel
for i in range(n_block - 1):
self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
n_channel *= 2
self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
if conv_lu:
print("conv_lu=True")
else:
print("conv_lu=False")
print("apply classifier_net for last latent z vector")
        # classifier on the final latent z; 48*8*8 matches 3-channel 64x64 inputs with n_block=3
        self.classifier_net = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(48 * 8 * 8, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 40),
)
def forward(self, input, label):
# print(f"get device of input:",input.get_device())
log_p_sum = 0
logdet = 0
out = input
z_outs = []
for block in self.blocks:
out, det, log_p, z_new = block(out, label)
z_outs.append(z_new)
logdet = logdet + det
if log_p is not None:
log_p_sum = log_p_sum + log_p
logits = self.classifier_net(z_new)
return log_p_sum, logdet, z_outs, logits
def reverse(self, z_list, label, reconstruct=False):
for i, block in enumerate(self.blocks[::-1]):
            if i == 0:
                # the last block has z == out == input, so its own z is fed back in
                input = block.reverse(z_list[-1], label, z_list[-1], reconstruct=reconstruct)
            else:
                # earlier blocks take the previous reverse output together with their stored z
                input = block.reverse(input, label, z_list[-(i + 1)], reconstruct=reconstruct)
return input
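# Example usage (assumed setup, not from the original post): the hard-coded classifier
# (48*8*8) and the 8x8 h_zero fit 3-channel 64x64 images with n_block=3. `label` is assumed
# to be a [BS, k] tensor of attribute indices in [0, 40), e.g. CelebA-style attributes.
if __name__ == "__main__":
    model = Glow(in_channel=3, n_flow=4, n_block=3).to(device)
    x = torch.rand(2, 3, 64, 64, device=device)
    label = torch.randint(0, 40, (2, 3), device=device)
    log_p, logdet, z_outs, logits = model(x, label)
    print(log_p.shape, logdet.shape, logits.shape)  # [2], [2], [2, 40]
    # a bits-per-dim style NLL is roughly -(log_p + logdet) / (log(2) * 3 * 64 * 64),
    # ignoring the dequantization offset
    x_rec = model.reverse(z_outs, label, reconstruct=True)
    print((x_rec - x).abs().max())  # small: the whole model is invertible up to float error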