channel attention、position attention、spatial attention
时间: 2025-05-09 09:46:48 浏览: 45
### 不同类型的注意力机制概念及其在深度学习中的应用
#### 通道注意力(Channel Attention)
通道注意力关注的是特征图各个通道之间的依赖关系。通过计算每个通道的重要性权重来增强有用的特征并抑制无用的信息。这种机制能够帮助模型聚焦于最具判别力的特征。
一种常见的实现方式是Squeeze-and-Excitation Networks (SENet)[^1],它引入了一个轻量级模块用于动态调整各卷积层输出特征图上的响应。具体来说,在SE模块内部会先全局池化得到描述整个图像上下文信息的一维向量;接着经过两个全连接层映射到与输入相同维度的空间中,并施加sigmoid激活函数获得0~1之间数值作为最终权值系数乘回原特征图上相应位置处。
```python
import torch.nn as nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
```
#### 位置注意力(Position Attention)
位置注意力侧重捕捉序列数据中元素间相对距离的影响。对于自然语言处理任务而言尤为重要,因为词语顺序决定了语义表达的不同含义。自回归变换器架构Transformer就是基于此原理构建而成,其核心在于多头自注意机制(Multi-head Self-Attention Mechanism),该方法允许网络在同一时间步长内考虑来自不同子空间的位置关联性[^2]。
在一个标准的Transformer编码单元里,Q、K、V矩阵分别代表查询(Query)、键(Key)以及值(Value),它们由线性投影后的词嵌入或者前一层隐藏状态生成。当计算当前时刻t的状态表示时,不仅要看自己本身的内容,还要综合考量其他所有时间节点上传递给自己的影响程度大小:
\[ \text{Attention}(Q,K,V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right)V \]
其中\( M \)是一个掩码矩阵用来屏蔽掉未来未发生事件的作用效果。
#### 空间注意力(Spatial Attention)
空间注意力则更注重二维平面上像素点间的相互作用规律。这类技术常被应用于视觉领域比如目标检测、分割等问题解决过程中。不同于传统固定窗口滑动扫描策略,采用灵活可变的感受野范围可以更好地适应复杂场景变化需求。例如CBAM(CoordAtt: Enhancing Convolution with Coordinate Attention)就在基础之上进一步提出了坐标轴方向上的细化操作——即除了常规意义上的宽高尺寸外还额外增加了水平垂直两组一维分布概率密度估计过程以期达到更加精准定位目的.
```python
def spatial_attention_module(input_tensor):
batch_size, channels, height, width = input_tensor.shape
# Compute average pooling across all channels at each spatial location.
avg_pooled = F.adaptive_avg_pool2d(input_tensor, output_size=(height, width))
# Compute max pooling similarly.
max_pooled = F.adaptive_max_pool2d(input_tensor, output_size=(height, width))
concat = torch.cat([avg_pooled, max_pooled], dim=1)
conv_output = conv(concat) # Assume 'conv' is defined elsewhere.
sigmoid_activated = torch.sigmoid(conv_output)
attended_feature_map = input_tensor * sigmoid_activated
return attended_feature_map
```
阅读全文
相关推荐
















