__all__ = (
"DFL",
"HGBlock",
"HGStem",
"SPP",
"SPPF",
"C1",
"C2",
"C3",
"C2f",
"C2fAttn",
"ImagePoolingAttn",
"ContrastiveHead",
"BNContrastiveHead",
"C3x",
"C3TR",
"C3Ghost",
"GhostBottleneck",
"Bottleneck",
"BottleneckCSP",
"Proto",
"RepC3",
"ResNetLayer",
"RepNCSPELAN4",
"ELAN1",
"ADown",
"AConv",
"SPPELAN",
"CBFuse",
"CBLinear",
"C3k2",
"C2fPSA",
"C2PSA",
"RepVGGDW",
"CIB",
"C2fCIB",
"Attention",
"PSA",
"SCDown",
"TorchVision",
"S2DP_C3K",
)
__all__ = (
"Conv",
"Conv2",
"LightConv",
"DWConv",
"DWConvTranspose2d",
"ConvTranspose",
"Focus",
"GhostConv",
"ChannelAttention",
"SpatialAttention",
"CBAM",
"Concat",
"RepConv",
"Index",
)
The above are the module declarations that already exist. Based on them, please revise and improve the code below, and write out the cleaned-up version:
import torch.nn as nn


class BaseModule(nn.Module):
    """Base module: provides pre-LayerNorm and a residual-connection helper."""

    def __init__(self, embed_dim):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)  # layer normalization

    def forward_with_residual(self, x, sublayer):
        """Residual connection: normalize first, apply the sublayer, then add the input."""
        return x + sublayer(self.norm(x))


class LinearAttention(BaseModule):
    """Multi-head self-attention sublayer with pre-norm and residual connection."""

    def __init__(self, embed_dim, num_heads):
        super().__init__(embed_dim)
        # By default nn.MultiheadAttention expects input of shape (seq_len, batch, embed_dim).
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # Run attention on the *normalized* input so the pre-norm in
        # forward_with_residual is actually applied (the original version
        # attended over the raw input and discarded the normalized one).
        return self.forward_with_residual(
            x, lambda normed: self.attention(normed, normed, normed)[0]
        )


class DepthwiseSeparableFFN(BaseModule):
    """Feed-forward sublayer (expand -> GELU -> project) with pre-norm and residual."""

    def __init__(self, embed_dim, expansion_ratio=4):
        super().__init__(embed_dim)
        expanded_dim = embed_dim * expansion_ratio
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, expanded_dim),
            nn.GELU(),
            nn.Linear(expanded_dim, embed_dim),
        )

    def forward(self, x):
        return self.forward_with_residual(x, self.ffn)


class EfficientTransformerLayer(nn.Module):
    """One transformer layer: attention sublayer followed by a feed-forward sublayer."""

    def __init__(self, embed_dim, num_heads, expansion_ratio=4):
        super().__init__()
        self.attention = LinearAttention(embed_dim, num_heads)  # attention sublayer
        self.ffn = DepthwiseSeparableFFN(embed_dim, expansion_ratio)  # feed-forward sublayer

    def forward(self, x):
        x = self.attention(x)  # attention first
        x = self.ffn(x)  # then the feed-forward network
        return x


class EfficientTransformer(nn.Module):
    """Stack of EfficientTransformerLayer blocks applied sequentially."""

    def __init__(self, num_layers, embed_dim, num_heads, expansion_ratio=4):
        super().__init__()
        self.layers = nn.ModuleList([
            EfficientTransformerLayer(embed_dim, num_heads, expansion_ratio)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
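A minimal usage sketch follows; the layer count, embedding size, head count, and tensor sizes are illustrative assumptions only, and the input is sequence-first, (seq_len, batch, embed_dim), which is what nn.MultiheadAttention expects by default:

import torch

# Hypothetical configuration for illustration only.
model = EfficientTransformer(num_layers=2, embed_dim=64, num_heads=4)
x = torch.randn(16, 8, 64)  # (seq_len, batch, embed_dim)
out = model(x)
print(out.shape)  # torch.Size([16, 8, 64])

If these classes are meant to live alongside the declarations listed above (e.g. inside the ultralytics modules package), their names would also need to be appended to the corresponding __all__ tuple so that star imports and the model parser can find them.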