Quick notes on the Glow paper

1. Description

The strength of the Glow model is that every layer defines its own forward computation together with a matching reverse computation, so the whole network is invertible by construction. The paper is excellent; these are hurried notes, the code below is incomplete and does not run yet, and both will be updated later. Glow also fits the transformation by splitting its weight matrix into factors (the LU-style decomposition of the invertible 1x1 convolution), which reads like a reappearance of MIT professor Gilbert Strang's linear algebra. These notes need to be restructured later with a better layout and clearer reasoning.
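
As a quick reminder of why the forward/reverse pairing matters (this is the standard change-of-variables objective from the Glow paper, added here for context): with an invertible map z = f(x) built from layers h_0 = x -> h_1 -> ... -> h_K = z, the exact log-likelihood is

$$\log p_\theta(x) = \log p_Z(z) + \sum_{i=1}^{K} \log\left|\det \frac{\partial h_i}{\partial h_{i-1}}\right|,$$

and sampling simply runs the reverse computations, x = f^{-1}(z). The matrix factorization mentioned above is Glow's LU parameterization of the invertible 1x1 convolution, $W = P\,L\,(U + \mathrm{diag}(s))$, which makes $\log\lvert\det W\rvert = \sum \log\lvert s\rvert$ cheap to compute.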


2. PyTorch (does not run yet; to be improved later)

import torch
import torch.nn as nn
from torch.nn import functional as F
from math import log, pi, exp
import numpy as np
from scipy import linalg as la

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logabs = lambda x: torch.log(torch.abs(x))


def gaussian_log_p(x, mean, log_sd):
    return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)


def gaussian_sample(eps, mean, log_sd):
    """
    // function: 从高斯分布中采用,其实就是参数重整化
    :param eps:
    :param mean:
    :param log_sd:
    :return:
    """
    return mean + torch.exp(log_sd) * eps


class ActNorm(nn.Module):
    def __init__(self, in_channel, logdet=True):
        super(ActNorm, self).__init__()
        self.in_channel = in_channel
        self.logdet = logdet
        # only the channel dimension is normalized, so loc and scale each hold in_channel values
        self.loc = nn.Parameter(torch.zeros(1, self.in_channel, 1, 1))  # data layout is BS*C*H*W
        self.scale = nn.Parameter(torch.ones(1, self.in_channel, 1, 1))
        self.initialized = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)

    def initialize(self, input):
        with torch.no_grad():
            # per-channel statistics over the batch and spatial dimensions
            mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
            std = torch.std(input, dim=(0, 2, 3), keepdim=True)
            # data-dependent initialization: the first batch comes out of actnorm with
            # zero mean and unit variance per channel, which stabilizes early training
            # (mean-variance normalization); copy_ assigns the values in place
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input):
        batch_size, _, height, width = input.shape
        # item() returns a Python scalar and only works on scalar tensors
        if self.initialized.item() == 0:
            self.initialize(input)
            # fill_ overwrites the flag in place
            self.initialized.fill_(1)
        log_abs = logabs(self.scale)
        # change in log-likelihood; a scalar
        logdet = height * width * torch.sum(log_abs)

        # broadcast to one value per sample in the batch
        logdet = logdet.repeat(batch_size)

        # forward computation
        if self.logdet:
            return self.scale * (input + self.loc), logdet
        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):
        return output / self.scale - self.loc
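
# Note: ActNorm applies y = scale * (x + loc) independently per channel, so its Jacobian is
# diagonal; every one of the H*W spatial positions contributes sum_c log|scale_c|, which is
# why forward() uses logdet = height * width * sum(log|scale|).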


class InvConv2d(nn.Module):
    def __init__(self, in_channel):
        super(InvConv2d, self).__init__()
        self.in_channel = in_channel
        weight = torch.randn(self.in_channel, self.in_channel)
        # initialize the 1x1 convolution with a random orthogonal (hence invertible) matrix
        q, _ = torch.linalg.qr(weight)
        weight = q.unsqueeze(2).unsqueeze(3)
        self.weight = nn.Parameter(weight)

    def forward(self, input):
        batch_size, _, height, width = input.shape
        out = F.conv2d(input, self.weight)
        logdet = (
                height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
        )
        return out, logdet

    def reverse(self, output):
        return F.conv2d(
            output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)
        )


class InvConv2dLU(nn.Module):
    def __init__(self, in_channel):
        super(InvConv2dLU, self).__init__()
        weight = np.random.randn(in_channel, in_channel)

        # QR decomposition: Q is orthogonal, R is upper triangular
        # (any matrix admits a QR decomposition)
        q, _ = la.qr(weight)

        # Q is orthogonal, so det(Q) = +/-1 != 0 and an LU decomposition exists:
        # Q = P @ L @ U, with P a permutation matrix, L lower triangular (unit diagonal,
        # per the scipy documentation), and U upper triangular
        w_p, w_l, w_u = la.lu(q.astype(np.float32))

        # diagonal of U as a 1-D array (these become the scale factors s)
        w_s = np.diag(w_u)

        # keep only the strictly upper-triangular part of U (its diagonal becomes 0)
        w_u = np.triu(w_u, 1)

        # mask with ones strictly above the diagonal, zeros elsewhere
        u_mask = np.triu(np.ones_like(w_u), 1)

        # mask with ones strictly below the diagonal, zeros elsewhere
        l_mask = u_mask.T

        w_p = torch.from_numpy(w_p).to(device)
        w_l = torch.from_numpy(w_l).to(device)
        w_s = torch.from_numpy(w_s).to(device)
        w_u = torch.from_numpy(w_u).to(device)

        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", torch.from_numpy(u_mask).to(device))
        self.register_buffer("l_mask", torch.from_numpy(l_mask).to(device))

        # sign of each scale factor (+1 or -1), kept fixed during training
        self.register_buffer("s_sign", torch.sign(w_s))

        # identity matrix: ones on the diagonal, zeros elsewhere
        self.register_buffer("l_eye", torch.eye(len(w_s)).to(device))

        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)

    def forward(self, input):
        batch_size, _, height, width = input.shape

        # weight has shape [in_channel, in_channel, 1, 1]
        weight = self.calc_weight()
        out = F.conv2d(input, weight)

        # w_s stores log|s|, so sum(w_s) = log|det W|; the 1x1 conv applies W at every
        # spatial position, hence the factor height * width
        logdet = height * width * torch.sum(self.w_s)
        logdet = logdet.repeat(batch_size)
        return out, logdet

    def calc_weight(self):
        # W = P @ (L * l_mask + I) @ (U * u_mask + diag(sign(s) * exp(log|s|)))
        # masking w_l / w_u is needed because they drift during training; the masks keep them
        # strictly lower / upper triangular, and the diagonal is rebuilt from the fixed signs
        # and the learned log-magnitudes w_s
        weight = (
            self.w_p
            @ (self.w_l * self.l_mask + self.l_eye)
            @ (self.w_u * self.u_mask + torch.diag(self.s_sign * torch.exp(self.w_s)))
        )
        # reshape to [C, C, 1, 1] so it can be used as a 1x1 convolution kernel
        return weight.unsqueeze(2).unsqueeze(3)

    def reverse(self, output):
        weight = self.calc_weight()
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
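
# A minimal sanity-check sketch (an illustration added to these notes, not part of the model):
# it assumes a 4-channel input and checks that reverse() inverts forward() and that the
# analytic logdet (H * W * sum(w_s)) matches torch.slogdet of the assembled weight.
#
#   conv = InvConv2dLU(4)
#   x = torch.randn(2, 4, 8, 8, device=device)
#   y, logdet = conv(x)
#   x_rec = conv.reverse(y)
#   assert torch.allclose(x, x_rec, atol=1e-4)
#   w = conv.calc_weight().squeeze()
#   assert torch.allclose(logdet[0], 8 * 8 * torch.slogdet(w.double())[1].float(), atol=1e-3)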


class ZeroConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, padding=1):
        super(ZeroConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)

        # zero initialization
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

        # shape [1, out_channel, 1, 1]; broadcasts over the batch and spatial dims
        # a learnable per-channel scale, also zero-initialized, so the layer starts
        # out as (almost) the zero map
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1))

    def forward(self, input):
        out = F.pad(input, [1, 1, 1, 1], value=1)
        out = self.conv(out)
        # per-channel rescaling; exp(0 * 3) = 1 at initialization, so at the start of
        # training the output is exactly that of the zero-initialized convolution
        out = out * torch.exp(self.scale * 3)
        return out


class AffineCoupling(nn.Module):
    def __init__(self, in_channel, filter_size=512, affine=True):
        super(AffineCoupling, self).__init__()
        self.affine = affine
        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            # ZeroConv2d is zero-initialized so the coupling starts as (nearly) an identity map
            # with affine=True the net outputs two halves: log_s and t
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2), )
        self.net[0].weight.data.normal_(0, 0.05)
        self.net[0].bias.data.zero_()

        self.net[2].weight.data.normal_(0, 0.05)
        self.net[2].bias.data.zero_()

    def forward(self, input):
        # split the input into halves a and b along the channel dimension
        in_a, in_b = input.chunk(2, 1)

        if self.affine:
            # the net produces log_s and t from half a
            log_s, t = self.net(in_a).chunk(2, 1)

            # the paper uses s = exp(log_s); sigmoid(log_s + 2) trains more stably
            s = torch.sigmoid(log_s + 2)

            # affine transform of half b (the paper writes y_b = s * x_b + t;
            # here the equivalent out_b = (in_b + t) * s is used)
            out_b = (in_b + t) * s

            # s has shape [BS, C//2, H, W] and depends on the input, so logdet has
            # shape [BS] instead of being a single scalar
            logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
        else:
            net_out = self.net(in_a)
            out_b = in_b + net_out
            logdet = None
        # half a passes through unchanged (out_a = in_a); concatenate the halves again
        return torch.cat([in_a, out_b], 1), logdet

    def reverse(self, output):
        out_a, out_b = output.chunk(2, 1)
        if self.affine:
            # invert the affine transform; out_a = in_a, so the net sees the same input
            log_s, t = self.net(out_a).chunk(2, 1)
            s = torch.sigmoid(log_s + 2)
            # in_b = out_b / s - t
            in_b = out_b / s - t
        else:
            net_out = self.net(out_a)
            in_b = out_b - net_out

        # half a was passed through unchanged (in_a = out_a); concatenate the halves again
        return torch.cat([out_a, in_b], 1)
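
# Why the coupling layer is cheap to invert: with x split into (x_a, x_b) and (log_s, t) = net(x_a),
#   forward:  y_a = x_a,   y_b = (x_b + t) * s
#   reverse:  x_a = y_a,   x_b = y_b / s - t
# The Jacobian is block-triangular with s on the diagonal, so per sample
#   logdet = sum(log s) over all elements,
# exactly what AffineCoupling.forward returns.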


class Flow(nn.Module):
    def __init__(self, in_channel, affine=True, conv_lu=True):
        super(Flow, self).__init__()
        self.actnorm = ActNorm(in_channel)

        if conv_lu:
            self.invconv = InvConv2dLU(in_channel)
        else:
            self.invconv = InvConv2d(in_channel)
        self.coupling = AffineCoupling(in_channel, affine=affine)

    def forward(self, input):
        out, logdet = self.actnorm(input)
        out, det1 = self.invconv(out)
        out, det2 = self.coupling(out)

        # accumulate the log-det contributions of the three sub-layers
        logdet = logdet + det1
        if det2 is not None:
            logdet = logdet + det2
        return out, logdet

    def reverse(self, output):
        input = self.coupling.reverse(output)
        input = self.invconv.reverse(input)
        input = self.actnorm.reverse(input)
        return input
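
# Composition rule used above: for a chain of invertible maps the log-determinants add,
#   log|det d(f3(f2(f1(x))))/dx| = log|det df1| + log|det df2| + log|det df3|,
# which is why Flow.forward simply sums the actnorm, invconv, and coupling contributions.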


class Block(nn.Module):
    """
    Each Block stacks several Flow steps (squeeze, n_flow flows, optional split).
    """

    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super(Block, self).__init__()
        squeeze_dim = in_channel * 4
        self.flows = nn.ModuleList()
        for i in range(n_flow):
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
        self.split = split

        if split:
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

        # conditioning path, only used when split=False (the last block): a learned base h_zero
        # for the prior plus a label embedding that shifts it; this is an extension on top of
        # vanilla Glow (40 label ids, embedding dim 32, spatial size hard-coded to 8x8)
        self.h_zero = nn.Parameter(torch.zeros(1, in_channel * 4, 8, 8), requires_grad=True)
        self.label_embedding = nn.Embedding(40, 32)
        self.proj_layer = nn.Linear(32, in_channel * 4)

    def forward(self, input, label):
        b_size, n_channel, height, width = input.shape

        # squeeze: trade spatial resolution for channels (H and W halved, channels x4)
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)
        logdet = 0

        for flow in self.flows:
            # each flow preserves the tensor shape, so output shape == input shape
            out, det = flow(out)
            logdet = logdet + det

        if self.split:
            out, z_new = out.chunk(2, 1)
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
        else:
            condition = self.label_embedding(label).sum(1)  # [BS,dim]
            condition = F.softplus(self.proj_layer(condition))

            condition = condition.unsqueeze(-1).unsqueeze(-1)

            mean, log_sd = self.prior(self.h_zero + condition).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out

        return out, logdet, log_p, z_new

    def reverse(self, output, label, eps=None, reconstruct=False):
        input = output

        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)
            else:
                input = eps
        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)
            else:
                condition = self.label_embedding(label).sum(1)  # [BS, dim], mirroring forward
                condition = F.softplus(self.proj_layer(condition))

                condition = condition.unsqueeze(-1).unsqueeze(-1)
                mean, log_sd = self.prior(self.h_zero + condition).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z

        for flow in self.flows[::-1]:
            input = flow.reverse(input)

        b_size, n_channel, height, width = input.shape

        # un-squeeze: restore [BS, 4c, H/2, W/2] back to [BS, c, H, W]
        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(
            b_size, n_channel // 4, height * 2, width * 2
        )
        return unsqueezed
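
# Notes on Block: the squeeze step is a fixed rearrangement of pixels into channels
# (space-to-depth), so it is trivially invertible and contributes nothing to logdet.
# With split=True, half of the channels are factored out as z_new and scored under a Gaussian
# whose mean/log_sd are predicted from the remaining half (Glow's multi-scale design);
# the final block (split=False) instead scores all of its output under a Gaussian prior
# conditioned on h_zero plus the projected label embedding.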


class Glow(nn.Module):
    def __init__(
            self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
        super(Glow, self).__init__()

        self.blocks = nn.ModuleList()
        n_channel = in_channel

        for i in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

        if conv_lu:
            print("conv_lu=True")
        else:
            print("conv_lu=False")

        print("apply classifier_net for last latent z vector")
        self.classifier_net = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(48 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 40),
        )

    def forward(self, input, label):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = []

        for block in self.blocks:
            out, det, log_p, z_new = block(out, label)
            z_outs.append(z_new)
            logdet = logdet + det

            if log_p is not None:
                log_p_sum = log_p_sum + log_p

        # classify from the latent of the last (non-split) block
        logits = self.classifier_net(z_new)
        return log_p_sum, logdet, z_outs, logits

    def reverse(self, z_list, label, reconstruct=False):
        for i, block in enumerate(self.blocks[::-1]):
            if i == 0:
                # last (non-split) block: its z is the whole block output
                input = block.reverse(z_list[-1], label, z_list[-1], reconstruct=reconstruct)
            else:
                # earlier blocks: the running input is the previous block's reconstructed output
                input = block.reverse(input, label, z_list[-(i + 1)], reconstruct=reconstruct)
        return input
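

# A minimal usage sketch (an illustration added to these notes, not part of the original code).
# Shapes below are assumptions: 3x64x64 inputs with n_block=3 and n_flow=4 make the final
# latent 48 x 8 x 8, which matches the hard-coded classifier_net input and the 8x8 h_zero;
# labels are assumed to be index tensors with values in [0, 40), as implied by nn.Embedding(40, 32).
if __name__ == "__main__":
    model = Glow(in_channel=3, n_flow=4, n_block=3).to(device)
    x = torch.rand(2, 3, 64, 64, device=device)          # dummy image batch
    label = torch.randint(0, 40, (2, 5), device=device)  # dummy multi-label indices

    log_p, logdet, z_outs, logits = model(x, label)
    # rough negative log-likelihood in bits per dimension (ignoring the dequantization term)
    n_pixel = 3 * 64 * 64
    loss = -(log_p + logdet).mean() / (log(2) * n_pixel)
    print(loss.item(), logits.shape)

    # exact reconstruction from the stored latents
    x_rec = model.reverse(z_outs, label, reconstruct=True)
    print((x - x_rec).abs().max().item())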