TransUNet:基于 Transformer 和 CNN 的混合编码网络

TransUNet是将Transformer结构引入到医学图像分割的UNet模型中,通过混合CNN和Transformer编码器,利用两者优势提高分割效果。实验表明,TransUNet在Synapse多器官分割和ACDC数据集上表现出色,其结构包括CNN下采样、Transformer编码和CNN解码。消融实验验证了跳跃连接和序列长度对模型精度的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Visual Transformer

Author:louwill

Machine Learning Lab

    

在深度学习医学图像分割领域,UNet结构一直以来都牢牢占据着主导地位。自从2015年提出U形结构以来,后续在UNet基础上做出的魔改网络不可计数。Tranformer结构逐渐开始用于视觉领域之后,基于UNet和Tranformer结合的相关结构和研究逐渐兴起。

UNet用了这么多年,效果好是毋庸置疑的。但硬要是找一些缺点,也不是找不到。由于CNN的平移不变性和捕捉长期依赖能力的不足,UNet在一定程度上依然有较大的提升空间。而Tranformer正好以捕捉序列之间的长期依赖而见长,将Tranformer结构融入到以CNN为主体的UNet中,能否进一步发挥UNet的威力呢?

答案是肯定的。今天我们要介绍的网络叫做TransUNet,正是一种充分结合UNet和Tranformer这两种结构的医学图像分割模型。提出TransUNet的论文为TransUNet:Transformers make strong encoders for medical image segmentation,发表于2021年2月,由约翰霍普金斯大学和电子科技大学等学校联合提出。

TransUNet结构

TransUNet完整结构如图1所示。

其中图(a)是一层Transformer结构示意图,图(b)是完整的TransUNet架构。Transformer结构不多说,对于图像块嵌入后,行常规的Layer Norm+MSA+MLP+残差连接结构处理。

我们重点看一下图(b)的TransUNet完整架构。完整的结构仍然是U形的编解码结构。先来看编码器部分,这也是TransUNet的关键部分。编码器部分先是对输入图像做了三层卷积下采样,对CNN得到的特征图进行图像块嵌入,同样也是要加位置编码,然后将块嵌入后的一维向量输入到12层Transformer结构中。所以TransUNet编码器的策略是CNN和Transformer混合构建编码器。这也是论文题目中make strong encoders的含义所在。

为什么要混合编码呢?这也是为了各自利用Transformer和CNN的优点来考虑的。Transformer更在注重全局信息,但容易忽略低分辨率下的图像细节,这对于解码器恢复像素尺寸伤害比较大,会导致分割结果很粗糙。而CNN正好可以弥补Transformer的这个缺点。所以混合编码在作者看来是大有裨益的。

然后是解码器,解码器比较简单,就是常规的转置卷积上采样恢复图像像素。同时从编码器的CNN下采样对应过来同层分辨率的级联。这些都属于原始的UNet的固有操作。

TransUNet实验

作者分别在Synapse多器官分割数据集和ACDC (自动化心脏诊断挑战赛)上实验了TransUNet的效果。具体地,对于混合编码器,论文中使用ResNet-50和ViT分别作为CNN和Transformer的backbone,并且都经过了ImageNet的预训练处理。

表1是TransUNet与VNet等模型的效果对比。


除了直接的模型精度比对之外,论文中还做了大量的消融实验研究。TransUNet的消融实验主要包括四个方面:1)跳跃连接数,2)输入图像分辨率,3)序列长度和图像分块大小,4)模型大小。

下面我们仅从第一个和第三个方面来看一下TransUNet的消融实验。第一个方面是尝试不同的跳跃连接数来观测模型分割的dice精度。对TransUNet网络分别不做添加、添加1和3条跳跃连接后的实验对比效果如图2所示。


实验结果也再一次强化了跳跃连接对于U形结构分割网络的强大效果。

消融实验的第三个方面是关于图像分块大小和序列长度对于模型精度影响的。当然这两个说的是一回事,图像分块尺寸越小,图像分块数量就越多,也就是序列越长。一般认为,patch size越小,Transformer序列越长,就越能编码出更为复杂的依赖关系。论文中分别实验了32、16和8三个尺寸的patch size,实验效果如表2所示。


图3显示了TransUNet、R50-ViT-CUP、AttentionUNet和UNet四个模型在多器官分割数据上的可视化效果。从视觉效果上的对比来看,TransUNet无疑是跟Ground Truth最为接近的了。


TransUNet代码实现

TransUNet完整代码实现可参考论文作者提供的仓库:

https://github.com/Beckschen/TransUNet

按照图1的模型架构,TransUNet最后的搭建代码如下所示。

class TransUNet(nn.Module):
        def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
            super(VisionTransformer, self).__init__()
            self.num_classes = num_classes
            self.zero_head = zero_head
            self.classifier = config.classifier
            self.transformer = Transformer(config, img_size, vis)
            self.decoder = DecoderCup(config)
            self.segmentation_head = SegmentationHead(
                in_channels=config['decoder_channels'][-1],
                out_channels=config['n_classes'],


                kernel_size=3,
            )
            self.config = config


        def forward(self, x):
            if x.size()[1] == 1:
                x = x.repeat(1,3,1,1)
            x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
            x = self.decoder(x, features)
            logits = self.segmentation_head(x)
            return logits

总结

TransUNet是率先将Transformer结构用于医学图像分割工作的研究。TransUNet将重视全局信息的Transformer结构和底层图像特征的CNN一起进行混合编码,能够更大程度上提升UNet的分割效果。

参考资料:
Chen J, Lu Y, Yu Q, et al. Transunet: Transformers make strong encoders for medical image segmentation[J]. arXiv preprint arXiv:2102.04306, 2021.
往期精彩:
 SETR:基于视觉 Transformer 的语义分割模型

 ViT:视觉Transformer backbone网络ViT论文与代码详解

【原创首发】机器学习公式推导与代码实现30讲.pdf
【原创首发】深度学习语义分割理论与实战指南.pdf
求个在看
### TransUNet 模型架构组成部分 #### 1. Transformer编码器部分 TransUNet采用基于ViT (Vision Transformers) 的编码器作为骨干网络,用于提取图像的全局特征表示。输入图像被划分为多个不重叠的小块(patch),这些patch通过线性投影转换成向量序列并送入多层Transformer编码器中处理[^1]。 ```python class ViT(nn.Module): def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12): super().__init__() self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) self.blocks = nn.ModuleList([ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True, drop_path=dpr[i]) for i in range(depth)]) ``` #### 2. 解码器部分 解码器由一系列卷积模块构成,负责逐步恢复空间分辨率并将低层次的空间细节信息与高层次语义信息相结合。为了更好地融合不同尺度的信息,在跳跃连接处采用了双线性插值方法来进行特征图尺寸匹配[^2]。 ```python def forward(self, x): features = [] # Encoder path for layer in self.encoder_layers: x = layer(x) features.append(x) # Decoder path with skip connections for idx, decoder_layer in enumerate(self.decoder_layers[::-1]): x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) x = torch.cat([x, features[-(idx + 2)]], dim=1) x = decoder_layer(x) ``` #### 3. 跳跃连接(Skip Connection) 类似于传统U-Net的设计理念,TransUNet同样保留了从浅层到深层再到浅层的信息传递路径。具体来说就是将下采样过程中丢失掉的一些边缘轮廓等细粒度特性重新加入到最后预测阶段之前的位置上去,从而提高最终输出的质量。 #### 4. 多尺度监督(Multi-scale Supervision) 除了最顶层外,其他各层也可以施加额外的损失函数项来指导训练过程;这样做不仅有助于加速收敛速度而且还能进一步提升泛化性能。这种策略特别适用于医学影像分析任务当中因为往往存在标注样本稀缺的问题。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值