P425
def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
pos_threshold=0.009999999):
conf, class_id = torch.max(cls_prob[1:], 0)
1. 输入 cls_prob
的结构
cls_prob
是cls_probs[i]
,即第i
个样本的分类概率,形状为(num_classes + 1, num_anchors)
。num_classes + 1
:包括所有目标类别 + 1个背景类(假设背景类是第 0 类)。num_anchors
:锚框的数量。
例如,如果有 20 个类别 + 1 个背景类,cls_prob.shape = (21, 1000)
(假设 1000 个锚框)。
2. cls_prob[1:]
的作用
cls_prob[1:]
表示 去掉背景类(第 0 类),只保留目标类别的概率。- 如果原始
cls_prob
是(21, 1000)
,那么cls_prob[1:]
就是(20, 1000)
。 - 这样做的目的是 只关心目标类别的置信度,忽略背景类。
- 如果原始
3. torch.max(cls_prob[1:], 0)
的作用
torch.max(..., dim=0)
表示 在第 0 维度(类别维度)上取最大值。- 返回值:
conf
:每个锚框的最大类别概率(置信度),形状(num_anchors,)
。class_id
:每个锚框对应的类别索引(从0
开始,对应cls_prob[1:]
的第 0 类,即原始cls_prob
的第 1 类)。
class_id = class_id[all_id_sorted]
conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]
这一行代码的作用是重新排序
conf[below_min_idx] = 1 - conf[below_min_idx]
这行代码 conf[below_min_idx] = 1 - conf[below_min_idx]
的作用是 对低置信度的预测框(被判定为背景的框)的置信度进行特殊处理。我们来详细解析它的逻辑和目的:
1. 上下文背景
在目标检测的后处理阶段,代码会先通过 pos_threshold
过滤掉低置信度的预测框:
below_min_idx = (conf < pos_threshold) # 找出置信度低于阈值的框
class_id[below_min_idx] = -1 # 将这些框的类别标记为背景(-1)
此时:
class_id = -1
表示该预测框是背景(无效检测)。- 但
conf
仍然存储原始的目标类别概率(可能很低,比如0.01
)。
2. conf[below_min_idx] = 1 - conf[below_min_idx]
的作用
这行代码对低置信度框的 conf
值做了一个反向操作:
- 原始
conf
:表示模型对“最可能目标类别”的预测概率(例如0.01
)。 - 修改后
conf
:变为1 - conf
(例如1 - 0.01 = 0.99
)。
non_keep
的conf
不用做上面的操作是因为non_keep
本身是通过nms
搞出来的,本来的conf
不没有很低