ViT:第一篇完全transformer CV分类
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale ICLR2021
Embedded Patches (197 x 768)
= Patch Embedding (196 x 768) +(concat) CLS Embedding (1 x 768) +(add) Position Embedding (197 x 768)
1.1 预处理
(1)原始输入
输入图像尺寸224x224x3,patch size大小为16x16,则有14 x 14=196个patch
H x W x C=N x(PxPxC) N为patch num,P为patch size
N = H x W/(P x P)=196
模型输入:H x W x C=N x(PxPxC)=196x768=NxD
即每个patch拉直后维度为768=PxPxC=16x16x3
(2)Liner projection of flattened patchs,即符号映射E
flattened patchs:196x768
E:fc层 大小是768→768
即,FC输入维度196 x 768→FC输出维度196 x 768
P是不固定的,那么得到的的每个patch的向量长度也是不一样的,为了模型不受patch size 大小的影响,作者引入了Linear Projection of Flattened Patches来把每个 P2⋅C 维的patch大小线性映射成固定的D维
(3)Extra CLS Embedding
额外加一个CLS Embedding(CLS token),大小为1 x 768
196x768→197 x 768,(N+1)xD concat
添加一个专门用于分类的的[class]Token,因为该Token没有予以信息,可以能好的反映图像表征;
a. CLS token输出作为整张图像用来分类的特征
b. CNN中,最后的feature map做GAP(全局平均池化)作为图像特征
为什么这么设置呢?不在所有patch的输出上使用GAP作为图像特征?
消融实验,二者效果差不多,为了和原始Transformer一致,继续使用[class] token的输出作为整张图像的特征
BETR设置?
最左边的 ∗ 称为learnable embedding,是一个与其他patch相同数据格式的可学习参数,目的是为了最后一步从该参数中获得分类信息。至于为什么要额外添加这个参数的目的是防止人为的指定导致图像分类结果偏向其中一个patch而造成误差。
(4)Position Embedding
加入位置信息矩阵,position embedding,大小为(N+1)xD,即197 x 768
197x768→197 x 768,(N+1)xD sum
1D和2D,一维和二维位置编码,消融实验结果差不多,继续使用1D位置编码
预处理:
1.2 transformer encoder
transformer encoder由L个transformer block组成
(1)embedded patchs
Embedded Patches (197 x 768)
= Patch Embedding (196 x 768) +(concat) CLS Embedding (1 x 768) +(add) Position Embedding (197 x 768)
encoder输入尺寸:(196+1)x 768=197 x 768
(2)MSA(multi-head self-attention)
12个head,每个head的Q、K、V维度为197 x (768/12)=197x64
每个head的结果拼接concat后输出为197x768
类似CNN中使用多个滤波器,有助于网络捕捉更丰富特征
(3)MLP
MLP:先放大,再缩小维度
197 x 768→197 x 3072→197 x 768
1.3 MLP Head output
CLS embedding的输出作为整个transformer encoder的输出,即整张图像的特征1x768维
196x768→1x768
但是我们在最初的想法就是在[class]token中保存着特征信息以便分类,所以在MLP Head过程中,我们将输出的shape(197,768)进行切片成需要的分类shape(1,768)。
1.4 结论
在小数据集(ImageNet)上预训练,ViT效果低于 BiT ResNet
在大数据集(JFT-300M)上预训练,ViT效果高于BiT ResNet
使用ViT,至少需要与ImageNet-21k相当的数据集规模,否则还是CNN效果更好
(1)embedding filter
embedding filter可视化,捕捉到了纹理颜色等信息,和CNN差不多
embedding filter是什么?
(2)position embedding
Position Embedding Similarity,1D位置编码,仍然学到了2D位置信息,所以消融实验二者效果差不多
位置编码在每个位置与其他位置的余弦相似度可以看出来,距离越近的图像块更相似,同一行/列的图像块相似度比较大
(3)attention distance 同CNN感受野大小
Mean attention distance,MSA每个head能注意到多远的距离(同CNN感受野大小)
浅层有的注意很近的距离,有的可以注意很远的距离,随着网络的深入,每个head都可以注意很远的距离
(4)self-supervision
仿照BERT,进行masked patch prediction,ViT-B/16在ImageNet比有监督低,效果不佳
1.5 微调
先去掉了预先训练的预测头,并附加一个零初始化的D*K前馈层,其中K是下游类的数量。
(1)高分辨率,P不变,N增加,导致position embedding不足
维度D=PxPxC不变,N太大
解决:2D插值
patch embedding : 1 2 3 4 5
pos:1,2,3,4,5
patch embedding plus : 1 2 3 4 5 6 7 8 9 10
pos:1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5
1.6 归纳偏差