论文阅读——MAT: Mask-Aware Transformer for Large Hole Image Inpainting

提出MAT,一种用于大型图像修复的Transformer框架,能够直接处理高分辨率图像。包括卷积头、改进的Transformer模块及风格操作模块,采用多头上下文注意力机制提升修复效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原文链接:

2022

CVPR 2022

MAT: Mask-Aware Transformer for Large Hole Image Inpainting [pdf] [code]

本文创新点:

  1. 开发了一种新颖的修复框架 MAT,是第一个能够直接处理高分辨率图像的基于 transformer 的修复系统。
  2. 提出了一种新的多头自注意力 (MSA) 变体,称为多头上下文注意力 (MCA),只使用有效的token来计算注意力。
  3. 设计了一个风格操作模块,使模型能够通过调节卷积的权重来提供不同的预测结果。

网络结构

网络分为粗修复与细修复两个阶段。粗修复主要由一个卷积头,五个transformer模块和一个卷积尾构成;细修复采用一个 Conv-U-Net 来细化高频细节。

Convolutional Head

卷积头主要由四个卷积层构成,将3*512*512的图像转换成180*64*64的特征图,用来提取token。

Transformer Body

本文对transformer模块进行了改进,一是删除了层归一化,二是将残差连接改成了全连接层。

删除层归一化的原因:在大面积区域缺失的情况下,大部分的token是无效的,而层归一化会放大这些无效的token,从而导致训练不稳定;

删除残差连接的原因:残差连接鼓励模型学习高频内容,然而在刚开始大多数的token是无效的,在训练过程中没有适当的低频基础,很难直接学习高频细节,如果使用残差连接就会使优化变得困难。

Multi-Head Contextual Attention

注意力模块利用移位窗口和动态掩码,只使用有效的token进行加权求和,

其中,表达式如下:

其中,为100。通过加上掩码,无效的token经过softmax后的权重几乎等于0。每次计算注意力后,将w*w大小的窗口的位置移动 (⌊  ⌋, ⌊  ⌋) 个位置,从而实现信息交互。

Mask Updating Strategy

更新规则:只要当前窗口有一个token是有效的,经过注意力后,该窗口中的所有token都会更新为有效的。如果一个窗口中的所有token都是无效的,经过注意力后,它们仍然无效。

Style Manipulation Module

它通过在带有额外噪声输入的重建过程中改变卷积层的权重归一化来操纵输出。为了增强噪声输入的表示能力,我们强制图像条件风格从图像特征X 和噪声无条件风格中学习,

其中,B为随机二值掩码(值为1的概率为p,为0的概率为1− p),εF都为映射函数,最终的风格是融合两种风格得到的

其中,A为映射函数,则卷积的权重W更新为

其中,i,j,k分别为输入通道,输出通道,卷积核的大小,ε为很小的常数。

损失函数

Adversarial Loss

Perceptual Loss.

Overall Loss

### 关于 Mask-Aware Transformer (MAT) 的权重、参数设置、模型训练及实现方法 #### 权重初始化与预训练策略 对于 MAT 中的权重初始化,遵循标准做法来确保网络能够稳定收敛。通常情况下,会使用 Xavier 或 He 初始化方式来设定初始权重值[^2]。 ```python import torch.nn as nn def init_weights(m): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) model.apply(init_weights) ``` #### 参数配置与优化器选择 在 MAT 训练过程中,超参数的选择至关重要。学习率一般设为较小数值(如 $1e^{-4}$),并采用 AdamW 作为优化算法以加速收敛过程。此外,还设置了 weight decay 和 gradient clipping 技术防止过拟合以及梯度爆炸现象的发生。 ```python optimizer = torch.optim.AdamW( model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01 ) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) ``` #### 动态掩码机制下的损失函数构建 为了更好地适应大范围缺失区域内的像素预测任务,MAT 设计了特殊的损失函数结构。除了常规 L1/L2 损失外,加入了感知损失(perceptual loss),风格损失(style loss)等多模态约束条件,从而提升生成图片的质量和真实性。 ```python criterion_l1 = nn.L1Loss() criterion_perceptual = PerceptualLoss() # 自定义或第三方库中的类 criterion_style = StyleLoss() total_loss = criterion_l1(output, target) + \ perceptual_weight * criterion_perceptual(output, target) + \ style_weight * criterion_style(output, target) ``` #### 高效推理阶段的技术手段 针对高分辨率输入带来的计算压力问题,MAT 提出了分块处理(batch processing)的方法论。即先将整张图分割成若干个小窗口(window),再分别送入到 transformer 层内完成特征提取工作;最后通过拼接操作恢复原始尺寸大小的结果图像。 ```python patch_size = 64 stride = patch_size // 2 patches = extract_patches(image_tensor, patch_size, stride=stride) outputs = [] for patch in patches: output_patch = mat_model(patch.unsqueeze(0)) outputs.append(output_patch.squeeze()) reconstructed_image = reconstruct_from_patches(outputs, original_shape=image.shape) ```
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值