一、背景
基于 IPM(逆透视变换)投影的相机视角转换对相机内外参依赖较高,且 BEV 网格融合方式固定,可能对小目标不够敏感;因此考虑通过 Transformer 的方式进行视角转换:BEV query 可以自适应地关注关键区域,注意力机制采样灵活,有利于提升小目标检测效果。故通过 BEVFormer 的 demo 代码理解其原理。
二、代码
import torch
import torch.nn as nn
# -------------------------
# Hyper-parameters for the toy BEVFormer-style pipeline
# -------------------------
B, C, H, W = 2, 64, 16, 16  # batch size, channels, camera feature-map height/width
bev_H, bev_W = 8, 8         # BEV grid resolution
num_cameras = 7             # surround-view camera count
num_classes = 10            # detection classes
num_det_queries = 32        # number of detection (object) queries
# -------------------------
# 1. Multi-camera features: one (B, N, C) tensor per camera, N = H * W
# -------------------------
camera_feats = [
    torch.randn(B, C, H * W).permute(0, 2, 1)  # (B, C, N) -> (B, N, C)
    for _ in range(num_cameras)
]
# -------------------------
# 2. Learnable BEV queries, one per BEV cell, laid out seq-first as
#    (num_bev_queries, B, C) to match nn.MultiheadAttention's default layout.
# -------------------------
num_bev_queries = bev_H * bev_W
bev_queries = nn.Parameter(torch.randn(num_bev_queries, B, C))
class BEVProjectionTransformer(nn.Module):
    """Cross-attention that lifts multi-camera image features into BEV space.

    Each BEV query attends over the concatenated features of every camera,
    so a BEV cell can adaptively pick out the image regions it needs.
    """

    def __init__(self, C, num_heads=8):
        super().__init__()
        # seq-first attention: inputs are (sequence, batch, channels)
        self.attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)

    def forward(self, bev_queries, camera_feats):
        """
        Args:
            bev_queries: (num_bev_queries, B, C) learnable queries.
            camera_feats: list of per-camera tensors, each (B, N, C).

        Returns:
            (num_bev_queries, B, C) attended BEV features.
        """
        # Concatenate all cameras into one long key/value sequence.
        kv = torch.cat(camera_feats, dim=1)  # (B, num_cameras * N, C)
        kv = kv.permute(1, 0, 2)             # (num_cameras * N, B, C)
        attended, _ = self.attn(bev_queries, kv, kv)
        return attended
# Lift multi-camera features into BEV space via cross-attention.
bev_proj_transformer = BEVProjectionTransformer(C)
bev_features = bev_proj_transformer(bev_queries, camera_feats)  # (num_bev_queries, B, C)
# Rearrange the seq-first output into a spatial BEV grid: (B, bev_H, bev_W, C).
bev_features_grid = bev_features.permute(1,0,2).reshape(B, bev_H, bev_W, C)
# -------------------------
# 3. Detection query + Transformer
# -------------------------
# Learnable object queries, seq-first layout to match nn.MultiheadAttention.
# NOTE(review): the parameter bakes in the batch dim B, so queries are not
# shared across batch entries — fine for a demo, unusual for real training.
det_queries = nn.Parameter(torch.randn(num_det_queries, B, C))
class DetectionDecoderTransformer(nn.Module):
    """Decoder stage: detection queries cross-attend to the BEV feature grid."""

    def __init__(self, C, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)

    def forward(self, det_queries, bev_features_grid):
        """
        Args:
            det_queries: (num_det_queries, B, C) learnable object queries.
            bev_features_grid: (B, H, W, C) spatial BEV feature map.

        Returns:
            (num_det_queries, B, C) per-query decoded features.
        """
        batch, grid_h, grid_w, channels = bev_features_grid.shape
        # Flatten the BEV grid into a seq-first (H*W, B, C) key/value sequence.
        bev_seq = bev_features_grid.reshape(batch, grid_h * grid_w, channels)
        bev_seq = bev_seq.permute(1, 0, 2)
        decoded, _ = self.attn(det_queries, bev_seq, bev_seq)
        return decoded
# Decode: detection queries attend over the BEV grid.
decoder = DetectionDecoderTransformer(C)
det_features = decoder(det_queries, bev_features_grid)  # (num_det_queries, B, C)
# -------------------------
# 4. Detection head
# -------------------------
class SimpleDetectionHead(nn.Module):
    """Per-query prediction heads: class logits plus a 7-value 3D box."""

    def __init__(self, C, num_classes):
        super().__init__()
        self.cls_head = nn.Linear(C, num_classes)
        # 7 box parameters — presumably (x, y, z, w, l, h, yaw); TODO confirm
        self.bbox_head = nn.Linear(C, 7)

    def forward(self, det_features):
        """Map (num_queries, B, C) features to (class logits, box params)."""
        return self.cls_head(det_features), self.bbox_head(det_features)
# Run the prediction heads and report output shapes.
detection_head = SimpleDetectionHead(C, num_classes)
cls_logits, bbox_preds = detection_head(det_features)
print("类别 logits shape:", cls_logits.shape) # num_det_queries x B x num_classes
print("3D bbox preds shape:", bbox_preds.shape) # num_det_queries x B x 7