Attention Mechanism Code Implementation
### Code Implementations of the Attention Mechanism
Below are simple implementations of attention mechanisms in two frameworks, PyTorch and TensorFlow:
#### SENet Channel Attention in PyTorch
```python
import torch
import torch.nn as nn


class SELayer(nn.Module):
    """Squeeze-and-Excitation (SE) channel attention block."""

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Squeeze: global average pooling reduces each channel to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP producing per-channel weights in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # (b, c) channel descriptors
        y = self.fc(y).view(b, c, 1, 1)   # per-channel scaling factors
        return x * y.expand_as(x)         # reweight the input feature map


# Test code
if __name__ == "__main__":
    input_tensor = torch.randn((8, 512, 7, 7))  # batch size 8, 512 channels, 7x7 spatial dims
    se_layer = SELayer(channel=512, reduction=16)
    output_tensor = se_layer(input_tensor)
    print(output_tensor.shape)  # expected: torch.Size([8, 512, 7, 7])
```
This code implements the channel attention mechanism from SENet[^4].
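In practice, an SE layer is usually placed after the convolutions of a backbone block so that its channel weights rescale that block's output. The following is a minimal sketch of one possible wiring; the `ConvSEBlock` name and the layer sizes are illustrative assumptions rather than part of the referenced implementation, and the sketch assumes the `SELayer` class above is in scope:

```python
import torch
import torch.nn as nn

# Hypothetical example: a small conv block whose output is reweighted by SELayer.
# Assumes the SELayer class defined above is available in the same module.
class ConvSEBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.se = SELayer(out_channels, reduction=reduction)

    def forward(self, x):
        x = self.conv(x)
        return self.se(x)  # channel-wise reweighting of the conv features


if __name__ == "__main__":
    block = ConvSEBlock(64, 128)
    feats = torch.randn(4, 64, 56, 56)
    print(block(feats).shape)  # expected: torch.Size([4, 128, 56, 56])
```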
#### Multi-Head Self-Attention in TensorFlow
```python
import tensorflow as tf
from tensorflow.keras.layers import Dense


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(QK^T / sqrt(d_k)) V and return the output and attention weights."""
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        # Masked positions get a large negative logit so their softmax weight is ~0
        scaled_attention_logits += (mask * -1e9)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights


class MultiHeadAttention(tf.keras.Model):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        # (batch, num_heads, seq_len, depth) -> (batch, seq_len, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
        return output, attention_weights


# Test code
if __name__ == "__main__":
    temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
    y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
    out, attn = temp_mha(y, k=y, v=y, mask=None)
    print(out.shape)   # expected: (1, 60, 512)
    print(attn.shape)  # expected: (1, 8, 60, 60)
```
The code above shows how to implement a multi-head self-attention mechanism in TensorFlow; it is one of the core components of the Transformer architecture[^3].
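To give a rough sense of how this module slots into a Transformer, the sketch below wraps `MultiHeadAttention` in a minimal encoder layer with a residual connection, layer normalization, and a position-wise feed-forward network. The `EncoderLayer` class, the `dff` width, and the dropout rate are illustrative assumptions, not part of the code above, and the sketch assumes the `MultiHeadAttention` class defined earlier is in scope:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, LayerNormalization

# Hypothetical sketch: a minimal Transformer encoder layer built around the
# MultiHeadAttention class defined above. Names and hyperparameters are illustrative.
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)  # self-attention sub-layer
        self.ffn = tf.keras.Sequential([
            Dense(dff, activation="relu"),                 # position-wise feed-forward
            Dense(d_model),
        ])
        self.norm1 = LayerNormalization(epsilon=1e-6)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        self.drop1 = Dropout(rate)
        self.drop2 = Dropout(rate)

    def call(self, x, training=False, mask=None):
        attn_output, _ = self.mha(x, k=x, v=x, mask=mask)  # self-attention: q = k = v = x
        out1 = self.norm1(x + self.drop1(attn_output, training=training))
        ffn_output = self.ffn(out1)
        return self.norm2(out1 + self.drop2(ffn_output, training=training))


if __name__ == "__main__":
    layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)
    x = tf.random.uniform((1, 60, 512))
    print(layer(x).shape)  # expected: (1, 60, 512)
```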
### On Choosing a Deep Learning Framework
Thanks to its dynamic computation graph, PyTorch is very convenient to debug: developers can step through code and inspect the state of intermediate variables just as they would in an ordinary Python script[^1]. Meanwhile, the Hugging Face Transformers library supports PyTorch as one of its primary backends, which further strengthens the appeal of the PyTorch ecosystem[^2]. TensorFlow, by contrast, traditionally built its strength on static graphs, but it has gradually introduced Eager Execution across its releases, narrowing the gap with PyTorch.
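As a concrete illustration of that debugging convenience (not something taken from the cited sources), the short sketch below relies on PyTorch's eager execution and a forward hook to print statistics of an intermediate activation. It assumes the `SELayer` class defined earlier is in scope; the hook body could just as well be a `print` statement or a debugger breakpoint placed directly inside `forward`:

```python
import torch

# Illustrative only: inspecting intermediate activations in PyTorch's eager mode.
# Assumes the SELayer class defined earlier is available in the same module.
se_layer = SELayer(channel=512, reduction=16)

def inspect_hook(module, inputs, output):
    # Runs during the forward pass; ordinary Python (print, pdb.set_trace, ...) works here.
    print("SELayer output mean/std:", output.mean().item(), output.std().item())

handle = se_layer.register_forward_hook(inspect_hook)
_ = se_layer(torch.randn(2, 512, 7, 7))  # triggers the hook as operations execute eagerly
handle.remove()
```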