论文速读|MoCo:Momentum Contrast for Unsupervised Visual Representation Learning.CVPR20

论文地址:Momentum Contrast for Unsupervised Visual Representation Learning
代码地址:https://github.com/facebookresearch/moco
bib引用:

@misc{he2020momentumcontrastunsupervisedvisual,
      title={Momentum Contrast for Unsupervised Visual Representation Learning}, 
      author={Kaiming He and Haoqi Fan and Yuxin Wu and Saining Xie and Ross Girshick},
      year={2020},
      eprint={1911.05722},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/1911.05722}, 
}

InShort

提出动量对比(MoCo)方法用于无监督视觉表征学习,通过构建动态字典解决相关问题,在多个视觉任务中取得优异成果,缩小了无监督和有监督表征学习之间的差距。【better than memory bank】

  1. 研究背景
    • 无监督表征学习现状:在自然语言处理中成果显著,如GPT和BERT,但在计算机视觉领域,监督式预训练仍占主导,无监督方法相对滞后。
    • 对比学习与字典构建:近期一些无监督视觉表征学习研究使用对比损失,可看作构建动态字典。本文认为理想的字典应具备大且一致的特点,并提出MoCo方法来构建这样的字典。
  2. 相关工作
    • 损失函数:常见的损失函数包括衡量模型预测与固定目标差异的函数,如L1、L2损失和交叉熵损失等。对比损失通过衡量样本对在表征空间的相似性来训练编码器,对抗损失则用于衡量概率分布差异。
    • pretext任务: pretext任务是为学习良好数据表征而设计的辅助任务,如恢复受损输入、基于单图像变换生成伪标签等,多种pretext任务可基于对比损失函数设计。
  3. 方法
    • 对比学习即字典查找:对比学习可视为训练编码器进行字典查找任务,InfoNCE是一种常用的对比损失函数,通过最小化该损失来训练编码器。
    • 动量对比(MoCo):将字典维护为数据样本队列,使字典大小与小批量大小解耦,可设置更大的字典。采用动量更新方式更新键编码器,使键编码器参数变化更平滑,保证字典中键的一致性。
    • pretext任务:采用简单的实例判别任务,将同一图像的不同随机增强视图作为正样本对,查询和键分别由不同编码器编码。同时采用了打乱BN(Batch Normalization)的方法解决模型“作弊”问题。
  4. 实验
    • 线性分类实验:在ImageNet-1M上进行无监督预训练,然后在冻结特征上训练线性分类器。对比了不同对比损失机制,结果表明MoCo在构建大字典和保持一致性方面具有优势,在相同模型规模下,MoCo的精度优于其他方法。
    • 特征迁移实验:将MoCo与ImageNet有监督预训练在多个下游任务中进行对比,包括PASCAL VOC和COCO的目标检测与分割、LVIS实例分割、Cityscapes实例分割和语义分割等任务。结果显示,MoCo在7个检测或分割任务中优于有监督预训练,在Cityscapes实例分割上与有监督预训练相当,在VOC语义分割上略逊一筹。
  5. 讨论与结论:MoCo在多种计算机视觉任务和数据集上取得了积极成果,但在利用大规模数据方面可能还有提升空间。未来可探索更先进的pretext任务,MoCo也有望在其他涉及对比学习的pretext任务中发挥作用。

摘要

我们提出了用于无监督视觉表示学习的 Momentum Contrast (MoCo)。从对比学习 [29] 作为字典查找的角度来看,我们构建了一个带有队列和移动平均编码器的动态字典。这样就可以动态构建一个大型且一致的字典,从而促进对比式无监督学习。MoCo 在 ImageNet 分类的通用线性协议下提供了有竞争力的结果。更重要的是,MoCo 学到的表示可以很好地转移到下游任务中。MoCo 可以在 PASCAL VOC、COCO 和其他数据集上的 7 项检测/分割任务中胜过其监督式预训练对手,有时甚至大大超过它。这表明,在许多视觉任务中,无监督和有监督的表征学习之间的差距已经基本缩小。

Introduction

无监督表示学习在自然语言处理中非常成功,例如,如 GPT [50, 51] 和 BERT [12] 所示。但有监督的预训练在计算机视觉中仍然占主导地位,而无监督方法通常落后。原因可能源于它们各自信号空间的差异。语言任务具有离散的信号空间(单词、子单词单元等),用于构建标记化词典,无监督学习可以基于这些词典。相比之下,计算机视觉进一步涉及字典构建 [54, 9, 5],因为原始信号位于连续的高维空间中,并且不是为人类交流而构建的(例如,与单词不同)。

最近的几项研究 [61, 46, 36, 66, 35, 56, 2] 在使用与对比损失相关的方法进行无监督视觉表征学习方面取得了有希望的结果 [29]。尽管这些方法的驱动动机多种多样,但可以将其视为构建动态词典。字典中的“键”(标记)是从数据(例如,图像或补丁)中采样的,并由编码器网络表示。无监督学习训练编码器执行字典查找:编码的 “query” 应与其匹配的 key 相似,而与其他 key 不同。学习被表述为最小化对比损失 [29]。

从这个角度来看,我们假设构建字典是可取的:(i) 大型和 (ii) 在训练过程中不断发展的一致性。直观地说,较大的字典可以更好地对底层连续的高维视觉空间进行采样,而字典中的键应该由相同或类似的编码器表示,以便它们与查询的比较是一致的。然而,使用对比损失的现有方法可能会在这两个方面之一受到限制(稍后将在上下文中讨论)。
在这里插入图片描述

图 1.Momentum Contrast (MoCo) 通过使用对比损失将编码的查询 q 与编码的键字典匹配来训练视觉表示编码器。字典键 k 0 , k 1 , k 2 , . . . {k_{0}, k_{1}, k_{2}, ...} k0,k1,k2,... 由一组数据样本动态定义。字典构建为队列,当前小批量排队,最早的小批量出队,使其与小批量大小分离。这些键由缓慢进展的编码器编码,由查询编码器的动量更新驱动。此方法启用了一个大型且一致的字典来学习视觉表示。

我们将动量对比 (MoCo) 作为一种为无监督学习构建大型且一致的词典的方法,但具有对比损失(图 1)。我们将字典维护为数据样本队列:当前小批量的编码表示将排队,而最早的表示将出列。队列将字典大小与小批量大小分离,使其较大。此外,由于字典键来自前面的几个小批量,因此提出了一个缓慢进展的键编码器,实现为查询编码器的基于动量的移动平均值,以保持一致性。

MoCo 是一种用于构建对比学习的动态词典的机制,可以与各种前置任务一起使用。在本文中,我们遵循一个简单的实例判别任务 [61, 63, 2]:如果它们是同一图像的编码视图(例如,不同的裁剪),则查询匹配一个键。使用这个借口任务,MoCo 在 ImageNet 数据集 [11] 中线性分类的通用协议下显示了有竞争力的结果。

无监督学习的一个主要目的是预先训练可以通过微调转移到下游任务的表示(即特征)。我们表明,在与检测或分割相关的 7 个下游任务中,MoCo 无监督预训练可以超越其 ImageNet 监督任务,在某些情况下,差距很大。在这些实验中,我们探索了在 ImageNet 或 10 亿个 Instagram 图像集上预先训练的 MoCo,证明 MoCo 可以在更真实、数十亿图像规模和相对未经策划的场景中很好地工作。这些结果表明,在许多计算机视觉任务中,MoCo 在很大程度上缩小了无监督和有监督表示学习之间的差距,并且可以在一些应用中作为 ImageNet 监督预训练的替代方案。

2. 相关工作

无监督/自我监督1 学习方法通常涉及两个方面:前置任务(pretext tasks)和损失函数。术语 “借口” 意味着所解决的任务不是真正的利益,而只是为了学习良好的数据表示的真正目的而解决。损失函数通常可以独立于 pretext 任务进行研究。MoCo 专注于损失函数方面。接下来,我们讨论这两个方面的相关研究。

2.1. 损失函数

损失函数。定义损失函数的一种常见方法是测量模型预测与固定目标之间的差异,例如通过 L1 或 L2 损失重建输入像素(例如,自动编码器),或通过交叉熵或基于边际的损失将输入分类为预定义的类别(例如,八个位置 [13],颜色区间 [64])。如下所述,其他替代方案也是可能的。

对比损失 [29] 衡量表示空间中样本对的相似性。在对比损失公式中,目标可以在训练过程中动态变化,并且可以根据网络计算的数据表示来定义,而不是将输入与固定目标匹配[29]。对比学习是最近几篇关于无监督学习的工作 [61, 46, 36, 66, 35, 56, 2] 的核心,我们将在后面的上下文中详细说明(第 3.1 节)。

对抗性损失 [24] 衡量概率分布之间的差异。这是一种广受欢迎的技术用于非监督式数据生成。表征学习的对抗方法在 [15, 16] 中进行了探讨。生成对抗网络和噪声对比估计 (NCE) 之间存在关系(参见 [24])[28]。

2.2. pretext task

前置任务。已经提出了各种各样的前置任务。示例包括在某些损坏的情况下恢复输入,例如,去噪自动编码器 [58]、上下文自动编码器 [48] 或跨通道自动编码器(着色)[64, 65]。一些前置任务通过例如单个(“示例”)图像的转换 [17]、补丁排序 [13, 45]、跟踪 [59] 或分割视频中的对象 [47] 或聚类特征 [3, 4] 来形成伪标签

对比学习与自监督学习中的前置任务(pretext tasks)之间的关系:
对比学习与前置任务的关联:在自监督学习里,前置任务是为了学习良好的数据表征而设计的辅助任务。许多前置任务会基于某种形式的对比损失函数来构建。这意味着对比损失函数在这些前置任务中起到关键作用,通过衡量样本对在表征空间的相似性,引导模型学习到有意义的特征表示。

补充1:对比学习与前置任务关联的具体示例

实例判别方法:实例判别方法与基于范例的任务以及噪声对比估计(NCE)相关。基于范例的任务是通过对单个 “范例” 图像进行变换来形成伪标签,实例判别方法在一定程度上借鉴了这种思路,同时也与NCE有联系,NCE本质是一种用于估计未归一化统计模型的方法,在实例判别中,可能会利用类似NCE的方式来构建对比损失,从而让模型学习到不同实例之间的差异。

  • 比如在一个图像数据集里,把同一图像的不同裁剪区域视为正样本对,不同图像的裁剪区域视为负样本对。模型的任务是学习如何区分这些样本对,通过最小化对比损失,让模型把来自同一图像的样本对的特征表示得更相似,而把来自不同图像的样本对的特征表示得更不同。像猫的图像,裁剪出猫的头部和身体部分作为正样本对,将猫的图像和狗的图像裁剪区域作为负样本对,训练模型识别同一物体不同部分的相似性和不同物体间的差异性。

对比预测编码(CPC):CPC中的前置任务本质上是一种上下文自动编码形式。它通过预测未来的上下文信息来学习特征表示,基于对比损失函数,让模型区分正确的上下文和错误的上下文,从而使模型能够捕捉到数据中的长期依赖关系,学习到更具代表性的特征。

  • 比如多模态/图文对比学习例子中,计算多模态对比损失,通过将预测的文本特征与实际文本特征和负样本特征进行对比来学习图像和文本的上下文关系。

对比多视图编码(CMC):CMC的前置任务与图像着色相关。在这个过程中,可能会利用图像不同视图之间的关系,基于对比损失函数,让模型学习到不同视图下图像特征的一致性和差异性,图像着色任务可能被用于提供不同的视图信息,辅助模型学习更鲁棒的视觉表示。

  • 对于图像任务,把同一图像经过不同颜色变换或几何变换后的视图作为不同视图。比如一张风景图像,一个视图是正常色彩和视角,另一个视图是经过色彩增强和轻微旋转后的图像。模型基于对比损失学习不同视图下图像特征的相似性和差异性,使模型在不同的图像变换下也能识别出是同一物体,提高模型对图像特征的鲁棒性 。

在这里插入图片描述

图 2.三种对比损失机制的概念比较(实证比较见图 3 和 表 3)。这里我们演示了一对 query 和 key。这三种机制在key的维护方式和密钥编码器的更新方式上有所不同。(a):用于计算查询和key 表示的encoder 通过反向传播进行端到端更新(两个编码器可以不同)。(b):关键表示是从内存库memory bank 中采样的 [61]。(c):MoCo 通过动量更新的编码器对新键进行动态编码,并维护一个键队列(图中未说明)。

3. 方法

3.1. 作为字典查找的对比学习

Contrastive learning [29], and its recent developments, can be thought of as training an encoder for a dictionary look-up task, as described next.

考虑一个编码查询 q 和一组编码样本 k 0 , k 1 , k 2 , . . . {k_{0}, k_{1}, k_{2}, ...} k0,k1,k2,...,它们是字典的键。假设字典中有一个键(表示为 k + k_{+} k+ ),该键匹配 q 。对比损失 [29] 是一个函数,当 q 与其正键 k + k_{+} k+ 相似,并且与所有其他键不同(被视为 q 的负键)时,其值较低。通过点积测量相似性,本文考虑了一种称为 InfoNCE [46] 的对比损失函数:
L q = − l o g e x p ( q ⋅ k + / τ ) ∑ i = 0 K e x p ( q ⋅ k i / τ ) \mathcal{L}_{q}=-log \frac{exp \left(q \cdot k_{+} / \tau\right)}{\sum_{i=0}^{K} exp \left(q \cdot k_{i} / \tau\right)} Lq=logi=0Kexp(qki/τ)exp(qk+/τ),其中 τ 是符合 [61] 的温度超参数。总和超过 1 个正样本和 K 个负样本。直观地说,这种损失是 ( K + 1 ) (K+1) (K+1) 方式基于 softmax 的分类器的对数损失,该分类器试图将 q 分类为 k + k_{+} k+。对比损失函数也可以基于其他形式 [29, 59, 61, 36],例如基于边际的损失和 NCE 损失的变体。

对比损失用作无监督目标函数,用于训练表示查询和键的编码器网络 [29]。通常,查询表示形式为 q = f q ( x q ) q=f_{q}(x^{q}) q=fq(xq),其中 f q f_{q} fq 是编码器网络, x 9 x^{9} x9 是查询样本(同样, k = f k ( x k ) ) k=f_{k}(x^{k})) k=fk(xk)) )。它们的实例化取决于特定的 pretext task。输入 x q x^{q} xq x k x^{k} xk 可以是图像 [29, 61, 63]、补丁 [46] 或由一组补丁组成的上下文 [46]。网络 f q f_{q} fq f k f_{k} fk 可以是相同的 [29, 59, 63]、部分共享的 [46, 36, 2] 或不同的 [56]。

3.2. 动量对比

从上述角度来看,对比学习是一种在高维连续输入(如图像)上构建离散字典的方法。字典是动态的,因为键是随机采样的,并且键编码器在训练期间不断发展。我们的假设是,一个好的特征可以通过一个包含大量负样本的大型字典来学习,而字典键的编码器尽管发生了变化,但尽可能地保持一致。基于这个动机,我们提出了 Momentum Contrast,如下所述。

3.2.1. Dictionary as a queue

我们方法的核心是将字典维护为数据样本队列。这允许我们重用前面的小批量中的编码密钥。队列的引入将字典大小与小批量大小分离。我们的字典大小可以比典型的小批量大小大得多,并且可以灵活且独立地设置为超参数。
字典中的样本将逐步替换。当前小批量将排入字典,并删除队列中最早的小批量。字典始终表示所有数据的采样子集,而维护此字典的额外计算是可管理的。此外,删除最旧的小批量可能是有益的,因为它的编码密钥是最过时的,因此与最新的密钥最不一致。

3.2.2. 动量更新

使用队列会使字典变大,但也会使通过反向传播来更新键编码器变得困难(渐变应传播到队列中的所有样本)。一种简单的解决方案是从查询编码器 f q f_{q} fq 中复制键编码器 f k f_{k} fk,忽略此梯度。但这种解决方案在实验中产生的结果很差(第 4.1 节)。我们假设这种失败是由快速变化的编码器引起的,这降低了key 表示的一致性。我们建议进行 momentum 更新来解决此问题。

正式地,将 f k f_{k} fk 的参数表示为 θ k \theta_{k} θk,将 f q f_{q} fq 的参数表示为 θ q \theta_{q} θq,我们通过以下方式更新 θ k \theta_{k} θk
θ k ← m θ k + ( 1 − m ) θ q . ( 2 ) \theta_{k} \leftarrow m \theta_{k}+(1-m) \theta_{q} . (2) θkmθk+(1m)θq.(2) 这里 m ∈ [ 0 , 1 ) m \in[0,1) m[0,1) 是一个动量系数。只有参数 θ q \theta_{q} θq 通过反向传播进行更新。方程(2) 中的动量更新使 θ k \theta_{k} θk θ q \theta_{q} θq 演化得更顺畅。因此,尽管队列中的键由不同的编码器编码(在不同的小批处理中),但这些编码器之间的差异可以很小。在实验中,相对较大的动量(例如, m = 0.999 m=0.999 m=0.999 1 我们的默认值)比较小的值(例如, m = 0.9 m=0.9 m=0.9)效果要好得多,这表明缓慢演变的键编码器是使用队列的核心。

与先前机制的关系。MoCo 是使用对比损失的通用机制。我们将其与图 2 中两个现有的通用机制进行了比较。它们在字典大小和一致性上表现出不同的属性。
通过反向传播进行端到端更新是一种自然机制(例如,[29, 46, 36, 63, 2, 35],图 2a)。它使用当前 mini-batch 中的样本作为字典,因此键的编码是一致的(由同一组编码器参数)。但是字典大小与小批量大小相耦合,受 GPU 内存大小限制。它也受到大型小批量优化的挑战 [25]。一些最近的方法 [46, 36, 2] 是基于由本地位置驱动的前置任务,其中字典的大小可以通过多个位置来变大。但是这些前置任务可能需要特殊的网络设计,例如修补输入[46]或自定义感受野大小[2],这可能会使这些网络向下游任务的转移复杂化。

另一种机制是 [61] 提出的memory bank 方法(图 2b)。内存库由数据集中所有样本的表示组成。每个小批量的字典都是从 memory bank 中随机采样的,没有反向传播,因此它可以支持较大的字典大小。但是,样本在Memory Bank 在最后一次看到时已更新,因此采样的键本质上是关于过去 epoch 中多个不同步骤的编码器的,因此不太一致。在 [61] 中,内存库采用了动量更新。它的动量更新是在同一样本的表示上,而不是编码器上。这种动量更新与我们的方法无关,因为 MoCo 不会跟踪每个样本。此外,我们的方法内存效率更高,并且可以在十亿级数据上进行训练,这对于内存库来说可能很困难。

在这里插入图片描述

3.3. Pretext Task

对比学习可以驱动各种前置任务。由于本文的重点不是设计一个新的前置任务,我们使用一个简单的任务,主要遵循[61]中的实例区分任务,最近的一些工作[63,2]与此相关。

按照[61],如果一个查询和一个键来自同一图像,我们将它们视为正样本对,否则视为负样本对。按照[63,2],我们在随机数据增强下对同一图像取两个随机“视图”以形成正样本对。查询和键分别由它们的编码器 f q f_{q} fq f k f_{k} fk编码。编码器可以是任何卷积神经网络[39]。

算法 1 提供了用于此自监督任务的 MoCo 伪代码。对于当前的小批次数据,对查询(queries)及其对应的键(keys)进行编码,它们形成正样本对。负样本来自队列。
采用 ResNet 作为编码器,其最后一个全连接层(在全局平均池化之后)具有固定维度的输出(128 维,参考文献[61])。这个输出向量通过其 L2 范数进行归一化,这是查询或键的表示形式。
公式(1)中的温度 T 设置为 0.07,参考文献[61]。数据增强设置参考[61],从随机调整大小的图像中截取 224×224 像素的裁剪区域,然后进行随机颜色抖动、随机水平翻转和随机灰度转换。

打乱批归一化(Shuffling BN)

编码器 f q f_{q} fq f k f_{k} fk都具有与标准 ResNet 中相同的批归一化(BN),参考文献[37]和[33]。在实验中发现,使用 BN 会阻止模型学习良好的表示,正如[35]中所报道的那样(该文献避免使用 BN)。模型似乎在“欺骗”自监督任务并容易找到低损失解。这可能是因为样本之间的批内通信(由 BN 引起)泄漏了信息。

我们通过打乱批归一化(BN)来解决这个问题。我们使用多个 GPU 进行训练,并在每个 GPU 上独立地对样本进行批归一化(这是常见的做法)。对于关键编码器 f k f_{k} fk,我们在将当前小批次中的样本分配到各个 GPU 之前打乱样本顺序(并且在编码后再恢复顺序);查询编码器 f q f_{q} fq的小批次样本顺序不做改变。这样可以确保用于计算查询及其正关键的批统计信息来自两个不同的子集。这有效地解决了作弊问题,并使训练能够从批归一化中受益。

我们在我们的方法及其端到端消融对应方法中都使用了随机打乱的批归一化(shuffled BN)(见图 2a)。它与记忆库对应方法(memory bank counterpart)(见图 2b)无关,记忆库对应方法不会受到这个问题的影响,因为正样本键(positive keys)来自过去不同的小批次数据。

5. 讨论和结论

我们的方法在各种计算机视觉任务和数据集中显示出无监督学习的积极结果。一些悬而未决的问题值得讨论。MoCo 从 IN-1M 到 IG-1B 的改进始终很明显,但相对较小,这表明更大规模的数据可能没有得到充分利用。我们希望高级 pretext 任务能够改善这一点。除了简单的实例区分任务 [61] 之外,还可以将 MoCo 用于伪装任务,例如在语言 [12] 和视觉 [46] 中进行自动编码掩码。我们希望 MoCo 对其他涉及对比学习的前置任务有用。

补充2:MoCo关键伪代码

# f_q, f_k: 分别为查询和键的编码器网络
# queue: 作为K个键的队列(CxK)表示字典
# m: 动量系数
# t: 温度超参数
import torch

# 初始化
m = 0.999  # 通常设置的动量系数值
t = 0.07  # 温度超参数值
f_k.params = f_q.params  # 初始化键编码器参数与查询编码器参数相同

for x in loader:  # 按批次加载数据
    # 对输入数据进行随机增强
    x_q = aug(x) 
    x_k = aug(x) 

    # 分别通过查询编码器和键编码器得到查询和键的特征表示
    q = f_q.forward(x_q)  
    k = f_k.forward(x_k)  
    k = k.detach()  # 不计算键的梯度

    # 计算正样本对数似然
    l_pos = torch.bmm(q.view(q.size(0), 1, -1), k.view(k.size(0), -1, 1))  

    # 计算负样本对数似然
    l_neg = torch.mm(q.view(q.size(0), -1), queue.view(-1, queue.size(1)))  

    # 合并正、负样本对数似然
    logits = torch.cat([l_pos, l_neg], dim=1)  

    # 定义标签(正样本为第0个)
    labels = torch.zeros(logits.size(0)).long()  

    # 计算对比损失(InfoNCE)
    loss = torch.nn.CrossEntropyLoss()(logits / t, labels)  

    # 使用随机梯度下降(SGD)更新查询网络
    loss.backward()  
    update(f_q.params)  

    # 动量更新键网络
    f_k.params = m * f_k.params + (1 - m) * f_q.params  

    # 更新字典队列
    enqueue(queue, k)  
    dequeue(queue)  

补充3:MoCo和Memory bank区别和伪代码

区别
Memory Bank
工作机制:使用一个大规模的缓冲区(memory bank)存储所有样本的特征向量。
在训练时,每次用当前批次的正样本特征作为正样本,其余从缓冲区中随机采样的特征作为负样本。
Memory Bank 中的特征通常是静态的,在一个或多个 epoch 后才会更新。

特点:①容纳的负样本数量非常大,通常等于整个训练集的样本数。②更新慢,容易导致特征陈旧(outdated)。③适合小规模数据集,但在大规模数据集上效率较低。

MoCo
工作机制:使用一个动量更新的队列(momentum-updated queue)代替静态的 Memory Bank。
动量机制使得队列中的负样本特征会随着模型训练逐步更新,保持相对较新的特性。
通过两个编码器(主编码器和动量编码器),动量编码器更新缓慢,生成稳定的负样本。
特点:①动态更新负样本,保持特征的时效性。②队列长度可控,内存占用较少。③更适合大规模数据集和高效训练。

总结
Memory Bank更简单,但可能会导致负样本特征陈旧,尤其是数据量大时
MoCo通过动量机制动态更新负样本特征,特征新鲜度更高,性能通常更优。

代码示例:参考ALBEF

基于MoCo的版本:https://github.com/salesforce/ALBEF/blob/main/models/model_pretrain.py

from functools import partial
from models.vit import VisionTransformer, interpolate_pos_embed
from models.xbert import BertConfig, BertForMaskedLM

class ALBEF(nn.Module):
    def __init__(self,text_encoder = None,
                 tokenizer = None,
                 config = None,    
                 temp = 0.07,
                 init_deit = True
                 ):
        super().__init__()
        
        self.tokenizer = tokenizer 
        self.mlm_probability = config['mlm_probability']
        embed_dim = config['embed_dim']
     
        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))   
        
        if init_deit:
            checkpoint = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
                map_location="cpu", check_hash=True)
            state_dict = checkpoint["model"]
            pos_embed_reshaped = interpolate_pos_embed(state_dict['pos_embed'], self.visual_encoder)
            state_dict['pos_embed'] = pos_embed_reshaped
            msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
            print(msg)          
            
        vision_width = config['vision_width']       
        bert_config = BertConfig.from_json_file(config['bert_config'])
        
        self.text_encoder = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config)      

        text_width = self.text_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)         

        self.temp = nn.Parameter(torch.ones([]) * config['temp'])   
        self.queue_size = config['queue_size']
        self.momentum = config['momentum']  
        self.itm_head = nn.Linear(text_width, 2)     

        # create momentum models
        self.visual_encoder_m = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config)       
        self.text_proj_m = nn.Linear(text_width, embed_dim)    
        
        self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
                            [self.vision_proj,self.vision_proj_m],
                            [self.text_encoder,self.text_encoder_m],
                            [self.text_proj,self.text_proj_m],
                           ]
        
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))  
                             
        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

    def forward(self, image, text, alpha=0):
        with torch.no_grad():
            self.temp.clamp_(0.001,0.5)
        
        image_embeds = self.visual_encoder(image) 
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)  

        text_output = self.text_encoder.bert(text.input_ids, attention_mask = text.attention_mask,                      
                                        return_dict = True, mode = 'text')            
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 
             
        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image) 
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
            image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
            text_output_m = self.text_encoder_m.bert(text.input_ids, attention_mask = text.attention_mask,                      
                                                return_dict = True, mode = 'text')    
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
            text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

            sim_i2t_m = image_feat_m @ text_feat_all / self.temp 
            sim_t2i_m = text_feat_m @ image_feat_all / self.temp     

            sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
            sim_targets.fill_diagonal_(1)          

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets        

        sim_i2t = image_feat @ text_feat_all / self.temp 
        sim_t2i = text_feat @ image_feat_all / self.temp 
                             
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 

        loss_ita = (loss_i2t+loss_t2i)/2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m)

        ###=================================###
        # forward the positve image-text pair
        output_pos = self.text_encoder.bert(encoder_embeds = text_embeds, 
                                        attention_mask = text.attention_mask,
                                        encoder_hidden_states = image_embeds,
                                        encoder_attention_mask = image_atts,      
                                        return_dict = True,
                                        mode = 'fusion',
                                       )            
        with torch.no_grad():
            bs = image.size(0)          
            weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)
            weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)
   
            weights_i2t.fill_diagonal_(0)
            weights_t2i.fill_diagonal_(0)

        # select a negative image for each text
        image_embeds_neg = []    
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text.attention_mask[neg_idx])
        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   
        text_atts_neg = torch.stack(text_atts_neg,dim=0)      

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)     

        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
        image_atts_all = torch.cat([image_atts,image_atts],dim=0)

        output_neg = self.text_encoder.bert(encoder_embeds = text_embeds_all, 
                                        attention_mask = text_atts_all,
                                        encoder_hidden_states = image_embeds_all,
                                        encoder_attention_mask = image_atts_all,      
                                        return_dict = True,
                                        mode = 'fusion',
                                       )                         

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
        vl_output = self.itm_head(vl_embeddings)            

        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)     
        
        ##================= MLM ========================##                
        input_ids = text.input_ids.clone()
        labels = input_ids.clone()

        probability_matrix = torch.full(labels.shape, self.mlm_probability)                    
        input_ids, labels = self.mask(input_ids, self.text_encoder.config.vocab_size, image.device, targets=labels,probability_matrix = probability_matrix) 
        
        with torch.no_grad():
            logits_m = self.text_encoder_m(input_ids, 
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds_m,
                                           encoder_attention_mask = image_atts,      
                                           return_dict = True,
                                           return_logits = True,   
                                          )    
        mlm_output = self.text_encoder(input_ids, 
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,      
                                       return_dict = True,
                                       labels = labels,   
                                       soft_labels = F.softmax(logits_m,dim=-1),
                                       alpha = alpha
                                      )                           
        loss_mlm = mlm_output.loss        

        return loss_mlm, loss_ita, loss_itm  

    @torch.no_grad()    
    def copy_params(self):
        for model_pair in self.model_pairs:           
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient    
            
    @torch.no_grad()        
    def _momentum_update(self):
        for model_pair in self.model_pairs:           
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
            
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr 
		...

在 ALBEF 的基础上引入 memory bank,

Memory Bank 的关键思路:

  1. 特征存储: 使用一个更大的队列或数组来存储特征,并与当前 batch 的特征进行对比。
  2. 更新方式: 用 momentum 或替换策略更新 memory bank,确保特征动态变化且始终保持多样性。
  3. 替代现有的队列:image_queuetext_queue 替换为全局 memory bank,允许更灵活的存储和采样策略。

Step 1: Memory Bank 初始化
image_queuetext_queue 替换为一个全局的 memory_bank_imagememory_bank_text

class ALBEF(nn.Module):
    def __init__(self,                 
                 text_encoder=None,
                 tokenizer=None,
                 config=None,    
                 temp=0.07,
                 init_deit=True
                 ):
        super().__init__()

        # 省略初始化的部分代码...

        # Memory Bank for image and text
        self.memory_bank_image = torch.randn(config['memory_bank_size'], config['embed_dim'])
        self.memory_bank_text = torch.randn(config['memory_bank_size'], config['embed_dim'])
        self.memory_bank_image = nn.functional.normalize(self.memory_bank_image, dim=1)
        self.memory_bank_text = nn.functional.normalize(self.memory_bank_text, dim=1)

        # Momentum update parameters
        self.momentum = config['momentum']
        self.memory_bank_size = config['memory_bank_size']
        self.memory_ptr = 0

Step 2: 更新 Memory Bank
_dequeue_and_enqueue 方法中,替换队列的更新逻辑,改为对 memory bank 进行替换。

    @torch.no_grad()
    def _update_memory_bank(self, image_feat, text_feat):
        """
        Update memory bank with current batch features.
        """
        batch_size = image_feat.size(0)

        # Update image memory bank
        self.memory_bank_image[self.memory_ptr:self.memory_ptr + batch_size] = image_feat
        # Update text memory bank
        self.memory_bank_text[self.memory_ptr:self.memory_ptr + batch_size] = text_feat

        # Normalize
        self.memory_bank_image = nn.functional.normalize(self.memory_bank_image, dim=1)
        self.memory_bank_text = nn.functional.normalize(self.memory_bank_text, dim=1)

        # Update pointer
        self.memory_ptr = (self.memory_ptr + batch_size) % self.memory_bank_size

Step 3: 在 Momentum 特征计算中加入 Memory Bank
forward 中,调整 text_feat_allimage_feat_all 的生成方式,将 memory bank 的特征加入到对比中。

        with torch.no_grad():
            # Update momentum features
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
            
            text_output_m = self.text_encoder_m.bert(text.input_ids, 
                                                     attention_mask=text.attention_mask, 
                                                     return_dict=True, mode='text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), dim=-1)

            # Combine with memory bank features
            image_feat_all = torch.cat([image_feat_m, self.memory_bank_image], dim=0)
            text_feat_all = torch.cat([text_feat_m, self.memory_bank_text], dim=0)

            # Compute similarity scores
            sim_i2t_m = image_feat_m @ text_feat_all.T / self.temp
            sim_t2i_m = text_feat_m @ image_feat_all.T / self.temp

Step 4: 计算损失并更新 Memory Bank
确保在完成对比学习损失后,更新 memory bank。

        # Contrastive loss
        sim_i2t = image_feat @ text_feat_all.T / self.temp
        sim_t2i = text_feat @ image_feat_all.T / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
        loss_ita = (loss_i2t + loss_t2i) / 2

        # Update memory bank
        self._update_memory_bank(image_feat, text_feat)

Step 5: 配置文件参数调整
在配置文件中,添加以下两个新参数:

config = {
    ...
    'memory_bank_size': 65536,  # Memory bank size
    'momentum': 0.999          # Momentum for updating features
}

以上是 ALBEF 的 memory bank 版本,主要体现在以下方面:

  1. memory_bank_imagememory_bank_text 提供特征存储。
  2. _update_memory_bank 方法动态更新 memory bank 。

进一步调整 memory bank 的使用策略:使用最近的 K 个样本代替全部样本。

补充4:MoCov2

MoCov2论文地址:Improved Baselines with Momentum Contrastive Learning

通过在Momentum Contrast(MoCo)框架中引入SimCLR的两种设计改进(MLP投影头和更强的数据增强),建立了更强的无监督学习基线,并证明这些改进在不需要大批次训练的情况下能够超越SimCLR。

1. 研究背景

文章聚焦于对比学习(contrastive learning),这是一种通过学习数据对的相似性/不相似性来生成表示的方法。MoCo和SimCLR是两种对比学习方法,MoCo通过动量编码器和队列机制处理大量负样本,而SimCLR依赖于大批次训练。本文的目标是通过改进MoCo框架,使其在不依赖大批次训练的情况下实现更好的性能。

2. 方法

文章在MoCo框架中引入了SimCLR的两个关键改进:

  • MLP投影头:将MoCo中原有的全连接(fc)投影头替换为两层MLP,隐藏层维度为2048,激活函数为ReLU。
  • 更强的数据增强:引入高斯模糊和更强的颜色失真。

此外,文章还研究了不同的温度参数τ对性能的影响,并采用余弦学习率调度来优化训练过程。

3. 实验

实验在1.28M ImageNet训练集上进行无监督学习,并通过以下两种协议评估性能:

  • ImageNet线性分类:冻结特征,训练监督线性分类器,报告单作物验证集的top-1准确率。
  • VOC目标检测:在VOC 07+12训练集上微调Faster R-CNN检测器,并使用COCO评估指标。
关键结果:
  • MLP投影头:使用MLP头后,ImageNet线性分类准确率从60.6%提升到62.9%(τ=0.07),进一步调整τ到0.2后,准确率达到66.2%。
  • 更强的数据增强:仅使用额外的高斯模糊和更强的颜色失真,ImageNet准确率提升到63.4%,与MLP头结合后,准确率进一步提升到67.3%。
  • 与SimCLR对比:MoCo v2在200个epoch和256批次大小下达到67.5%的准确率,比SimCLR在相同设置下高出5.6%,并且超过了SimCLR在4096批次大小下的66.6%。

4. 计算成本

文章还比较了MoCo和SimCLR的计算成本:

  • MoCo在256批次大小下,内存占用为5.0G,训练时间为53小时。
  • SimCLR在相同批次大小下,内存占用为7.4G,训练时间为65小时。
  • SimCLR在4096批次大小下,内存占用高达93.0G,难以在普通8-GPU机器上运行。

5. 结论

文章通过在MoCo框架中引入MLP投影头和更强的数据增强,建立了更强的无监督学习基线(MoCo v2)。这些改进不仅提升了性能,还降低了对大批次训练的依赖,使得最先进的无监督学习方法更加易于实现。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值