BEV:隐式相机视角转换-----BEVFormer

一、背景

基于imp投影的相机视角转换,对相机的内外参依赖较高,BEV 网格融合固定,可能对小目标不够敏感;考虑通过transformer的方式进行相机的视角转换,BEV query 可以自适应关注关键区域,提高小目标检测,transformer 注意力机制,灵活采样。故通过BEVFormer的demo代码理解其原理。

二、代码

import torch
import torch.nn as nn

# -------------------------
# 参数
# -------------------------
B, C, H, W = 2, 64, 16, 16        # 摄像头特征图
bev_H, bev_W = 8, 8                # BEV 网格
num_cameras = 7
num_classes = 10
num_det_queries = 32                # detection query 数量

# -------------------------
# 1. 多摄像头特征
# -------------------------
camera_feats = [torch.randn(B, C, H*W) for _ in range(num_cameras)]  # B x C x N (N=H*W)
for i in range(num_cameras):
    camera_feats[i] = camera_feats[i].permute(0, 2, 1)  # B x N x C

# -------------------------
# 2. BEV query + Transformer 投影
# -------------------------
num_bev_queries = bev_H * bev_W
bev_queries = nn.Parameter(torch.randn(num_bev_queries, B, C))

class BEVProjectionTransformer(nn.Module):
    def __init__(self, C, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)

    def forward(self, bev_queries, camera_feats):
        """
        bev_queries: num_bev_queries x B x C
        camera_feats: list of B x N x C
        """
        # 拼接所有摄像头特征
        feats = torch.cat(camera_feats, dim=1)      # B x (num_cameras*N) x C
        feats = feats.permute(1,0,2)               # (num_cameras*N) x B x C
        bev_out, _ = self.attn(bev_queries, feats, feats)
        return bev_out

bev_proj_transformer = BEVProjectionTransformer(C)
bev_features = bev_proj_transformer(bev_queries, camera_feats)
bev_features_grid = bev_features.permute(1,0,2).reshape(B, bev_H, bev_W, C)

# -------------------------
# 3. Detection query + Transformer
# -------------------------
det_queries = nn.Parameter(torch.randn(num_det_queries, B, C))

class DetectionDecoderTransformer(nn.Module):
    def __init__(self, C, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)

    def forward(self, det_queries, bev_features_grid):
        B, H, W, C = bev_features_grid.shape
        bev_flat = bev_features_grid.reshape(B, H*W, C).permute(1,0,2)
        out, _ = self.attn(det_queries, bev_flat, bev_flat)
        return out

decoder = DetectionDecoderTransformer(C)
det_features = decoder(det_queries, bev_features_grid)

# -------------------------
# 4. Detection head
# -------------------------
class SimpleDetectionHead(nn.Module):
    def __init__(self, C, num_classes):
        super().__init__()
        self.cls_head = nn.Linear(C, num_classes)
        self.bbox_head = nn.Linear(C, 7)

    def forward(self, det_features):
        cls_logits = self.cls_head(det_features)
        bbox_preds = self.bbox_head(det_features)
        return cls_logits, bbox_preds

detection_head = SimpleDetectionHead(C, num_classes)
cls_logits, bbox_preds = detection_head(det_features)

print("类别 logits shape:", cls_logits.shape)     # num_det_queries x B x num_classes
print("3D bbox preds shape:", bbox_preds.shape)   # num_det_queries x B x 7

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值