本文介绍Detect_NMS、OBB_NMS、Polygon_NMS、Pose_NMS、Seg_NMS、Class_NMS等等
后续会添加更多任务的NMS
- Pose_NMS和Seg_NMS只需要对检测边框进行NMS,关键点和掩码无需NMS
- Class_NMS无需NMS,仅排序
一、Detect_NMS
Detect_NMS是目标检测后处理。
-
C++代码如下:
// One detection candidate handled by the Detect_NMS routine in this section.
struct Object
{
// Axis-aligned bounding box with float coordinates; its area and overlap drive the IoU test.
cv::Rect_<float> rect;
// Class index; nms_sorted_bboxes compares it so that, unless agnostic, only same-label boxes suppress each other.
int label;
// Detection score. NOTE(review): presumably the confidence the caller sorts by before NMS - confirm.
float prob;
};
//=====Detect_IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
cv::Rect_<float> inter = a.rect & b.rect;
return inter.area();
}
//=====Detect_NMS=====//
// Greedy non-maximum suppression over axis-aligned boxes.
//
// objects        Candidate detections. As the function name indicates, they are
//                expected to be sorted already (presumably by descending score);
//                this is not checked here.
// picked         Output: cleared, then filled with the indices into `objects`
//                of the boxes that survive, in ascending index order.
// nms_threshold  A candidate is dropped when its IoU with an already kept box
//                is strictly greater than this value.
// agnostic       false (default): only boxes with the same label can suppress
//                each other. true: labels are ignored.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Every pair test needs both box areas, so compute them once up front.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int k : picked)
        {
            const Object& b = objects[k];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            // (two zero-area boxes give 0/0 = NaN, which compares false, so
            // such a candidate is never suppressed)
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[k] - inter_area;
            if (inter_area / union_area > nms_threshold)
            {
                // One sufficiently overlapping kept box is enough to reject
                // the candidate; the remaining kept boxes cannot change that.
                keep = false;
                break;
            }
        }

        if (keep)
            picked.push_back(i);
    }
}
-
C#代码如下:
// One axis-aligned detection as consumed by ApplyNMS in this section.
public class DetectionResult
{
// Class index; ApplyNMS groups detections by this value.
public int ClassId{ get; set; }
// Class name. NOTE(review): presumably the label text that corresponds to ClassId - confirm.
public string Class{ get; set; }
public Rect Rect{ get; set; } // Bounding rectangle of the detection (CalculateIoU reads its X, Y, Width and Height)
// Detection score; ApplyNMS filters and sorts on it.
public float Confidence{ get; set; }
}
//=====Detect_IOU=====//
// Intersection-over-union of two axis-aligned rectangles.
// Returns 0 when the union area is not greater than float.Epsilon.
private static float CalculateIoU(Rect rect1, Rect rect2)
{
    // Edges of the overlap region (may be inverted when the rectangles are disjoint).
    int left = Math.Max(rect1.X, rect2.X);
    int top = Math.Max(rect1.Y, rect2.Y);
    int right = Math.Min(rect1.X + rect1.Width, rect2.X + rect2.Width);
    int bottom = Math.Min(rect1.Y + rect1.Height, rect2.Y + rect2.Height);

    // Clamp each extent at zero so disjoint rectangles give an empty overlap.
    int overlapWidth = Math.Max(0, right - left);
    int overlapHeight = Math.Max(0, bottom - top);
    int overlap = overlapWidth * overlapHeight;

    float firstArea = rect1.Width * rect1.Height;
    float secondArea = rect2.Width * rect2.Height;
    float combined = firstArea + secondArea - overlap;

    if (combined > float.Epsilon)
    {
        return overlap / combined;
    }
    return 0f;
}
//=====Detect_NMS=====//
// Per-class greedy non-maximum suppression for axis-aligned detections.
//
// detections           All candidate detections; the list itself is not modified.
// iouThreshold         A box is removed when its IoU with a kept box of the same
//                      class is strictly greater than this value.
// confidenceThreshold  Detections with Confidence below this value are discarded
//                      before NMS.
//
// Returns the kept detections, class by class, each class ordered by
// descending confidence.
public static List<DetectionResult> ApplyNMS(List<DetectionResult> detections, float iouThreshold = 0.5f, float confidenceThreshold = 0.5f)
{
    List<DetectionResult> finalResults = new List<DetectionResult>();

    // 1. Drop low-confidence detections. 2. Group by class so that boxes of
    //    different classes never suppress each other.
    var classGroups = detections
        .Where(d => d.Confidence >= confidenceThreshold)
        .GroupBy(d => d.ClassId);

    // 3. Run greedy NMS inside each class.
    foreach (var classGroup in classGroups)
    {
        // Private copy of the group, highest confidence first.
        List<DetectionResult> remaining = classGroup.ToList();
        remaining.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));

        while (remaining.Count > 0)
        {
            // The best remaining box is always kept ...
            var current = remaining[0];
            finalResults.Add(current);
            remaining.RemoveAt(0);

            // ... and every remaining box that overlaps it too much is dropped.
            remaining.RemoveAll(other => CalculateIoU(current.Rect, other.Rect) > iouThreshold);
        }
    }

    return finalResults;
}
或者:
// Alternative to the hand-written ApplyNMS: pass the rectangles and confidences of `results`
// to CvDnn.NMSBoxes with both thresholds set to 0.5; the outcome is written to `indices`.
// NOTE(review): presumably `indices` holds the positions in `results` of the kept boxes, and
// this call does not separate classes - confirm against the OpenCvSharp documentation.
CvDnn.NMSBoxes(results.Select(x => x.Rect), results.Select(x => x.Confidence), scoreThreshold: 0.5f, nmsThreshold: 0.5f, out int[] indices);
二、OBB_NMS
OBB_NMS是旋转目标检测后处理
C++代码
// One oriented (rotated) detection candidate handled by the OBB_NMS routine in this section.
struct Object
{
// Rotated bounding box; its size gives the box area and its outline is intersected for IoU.
cv::RotatedRect rrect;
// Class index; nms_sorted_bboxes compares it so that, unless agnostic, only same-label boxes suppress each other.
int label;
// Detection score. NOTE(review): presumably the confidence the caller sorts by before NMS - confirm.
float prob;
};
//=====OBB_IOU====//
static inline float intersection_area(const Object& a, const Object& b)
{
std::vector<cv::Point2f> intersection;
cv::rotatedRectangleIntersection(a.rrect, b.rrect, intersection);
if (intersection.empty())
return 0.f;
return cv::contourArea(intersection);
}
//=====OBB_NMS=====//
// Greedy non-maximum suppression over rotated boxes.
//
// objects        Candidate detections. As the function name indicates, they are
//                expected to be sorted already (presumably by descending score);
//                this is not checked here.
// picked         Output: cleared, then filled with the indices into `objects`
//                of the boxes that survive, in ascending index order.
// nms_threshold  A candidate is dropped when its IoU with an already kept box
//                is strictly greater than this value.
// agnostic       false (default): only boxes with the same label can suppress
//                each other. true: labels are ignored.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // The area of a rotated box is simply width * height of its size;
    // compute it once per box because every pair test needs it.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rrect.size.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int k : picked)
        {
            const Object& b = objects[k];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            // (two zero-area boxes give 0/0 = NaN, which compares false, so
            // such a candidate is never suppressed)
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[k] - inter_area;
            if (inter_area / union_area > nms_threshold)
            {
                // One sufficiently overlapping kept box is enough to reject
                // the candidate; the remaining kept boxes cannot change that.
                keep = false;
                break;
            }
        }

        if (keep)
            picked.push_back(i);
    }
}
C#代码
// One oriented (rotated) detection as consumed by the OBB ApplyNMS in this section.
public class OBBResult
{
// Class index; ApplyNMS groups detections by this value.
public int ClassId{ get; set; }
// Class name. NOTE(review): presumably the label text that corresponds to ClassId - confirm.
public string Class{ get; set; }
public RotatedRect RotaRect{ get; set; } // RotatedRect is built from (new Point2f(centerX, centerY), new Size2f(width, height), rotation angle)
// Detection score; ApplyNMS filters and sorts on it.
public float Confidence{ get; set; }
}
//=====RotatedIOU=======//
// Intersection-over-union of two rotated rectangles.
// The overlap area comes from Cv2.IntersectConvexConvex applied to the corner
// points of both rectangles; if that call throws, the overlap is treated as 0.
// Returns 0 when the union area is not positive.
private static float CalculateRotatedIoU(RotatedRect rect1, RotatedRect rect2)
{
    // Corner points of each rectangle.
    Point2f[] cornersA = rect1.Points();
    Point2f[] cornersB = rect2.Points();

    float overlap;
    try
    {
        overlap = Cv2.IntersectConvexConvex(cornersA, cornersB, out Point2f[] overlapPolygon);
    }
    catch
    {
        // Best effort: a failed intersection counts as "no overlap".
        overlap = 0f;
    }

    float sizeA = rect1.Size.Width * rect1.Size.Height;
    float sizeB = rect2.Size.Width * rect2.Size.Height;
    float combined = sizeA + sizeB - overlap;

    if (combined > 0)
    {
        return overlap / combined;
    }
    return 0f;
}
//=====OBB_NMS=======//
// Per-class greedy non-maximum suppression for rotated detections.
//
// detections           All candidate detections; the list itself is not modified.
// iouThreshold         A box is removed when its rotated IoU with a kept box of
//                      the same class is strictly greater than this value.
// confidenceThreshold  Detections with Confidence below this value are discarded
//                      before NMS.
//
// Returns the kept detections, class by class, each class ordered by
// descending confidence.
public static List<OBBResult> ApplyNMS(List<OBBResult> detections, float iouThreshold = 0.5f, float confidenceThreshold = 0.5f)
{
    List<OBBResult> finalResults = new List<OBBResult>();

    // 1. Drop low-confidence detections. 2. Group by class so that boxes of
    //    different classes never suppress each other.
    var classGroups = detections
        .Where(d => d.Confidence >= confidenceThreshold)
        .GroupBy(d => d.ClassId);

    // 3. Run greedy NMS inside each class.
    foreach (var classGroup in classGroups)
    {
        // Private copy of the group, highest confidence first.
        List<OBBResult> remaining = classGroup.ToList();
        remaining.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));

        while (remaining.Count > 0)
        {
            // The best remaining box is always kept ...
            var current = remaining[0];
            finalResults.Add(current);
            remaining.RemoveAt(0);

            // ... and every remaining box that overlaps it too much is dropped.
            remaining.RemoveAll(other => CalculateRotatedIoU(current.RotaRect, other.RotaRect) > iouThreshold);
        }
    }

    return finalResults;
}
三、Polygon_NMS
多边形检测NMS
C++代码
// One polygon detection candidate handled by the Polygon_NMS routine in this section.
struct polygon
{
std::vector<cv::Point2f> boundingBoxPoints; // vertices of the polygon (PolygonIoU only accepts exactly four)
float confidence; // confidence score
int classIndex; // class index
};
//=====Polygon_IOU=====//
// Intersection-over-union of two convex quadrilaterals.
//
// PolyBox1, PolyBox2  Vertex lists; each must contain exactly four points,
//                     otherwise 0 is returned. cv::intersectConvexConvex
//                     requires the polygons to be convex.
//
// Returns intersection / (union + 1e-6); the small constant avoids a division
// by zero for degenerate polygons.
float PolygonIoU(const std::vector<cv::Point2f>& PolyBox1, const std::vector<cv::Point2f>& PolyBox2)
{
    if (PolyBox1.size() != 4 || PolyBox2.size() != 4) {
        return 0.0f;
    }

    // Work on the float vertices directly. Converting them to integer
    // cv::Point first truncates every coordinate, which distorts the areas
    // and collapses small or normalised polygons to zero area.
    std::vector<cv::Point2f> intersection;
    const float intersectionArea = cv::intersectConvexConvex(PolyBox1, PolyBox2, intersection);

    // Areas of the two input polygons.
    const float area1 = static_cast<float>(cv::contourArea(PolyBox1));
    const float area2 = static_cast<float>(cv::contourArea(PolyBox2));

    const float unionArea = area1 + area2 - intersectionArea;
    return intersectionArea / (unionArea + 1e-6f);
}
//=====Polygon_NMS======//
// Greedy non-maximum suppression over polygon detections.
//
// results       Candidate polygons. NOTE(review): the greedy scan only makes
//               sense if they are already sorted by descending confidence;
//               this is not checked here - confirm at the call site.
// picked        Output: cleared, then filled with the indices into `results`
//               of the polygons that survive, in ascending index order.
// iouThreshold  A candidate is dropped when its IoU with an already kept
//               polygon is greater than or equal to this value.
// agnostic      false (default): only polygons with the same classIndex can
//               suppress each other. true: class indices are ignored.
//
// Only the first 300 candidates are examined; later ones are never picked.
static void PolygonNonMaxSuppression(const std::vector<polygon>& results, std::vector<int>& picked, float iouThreshold, bool agnostic = false)
{
    const int kMaxCandidates = 300;

    picked.clear();

    const int total = static_cast<int>(results.size());
    const int n = (total > kMaxCandidates) ? kMaxCandidates : total;

    for (int i = 0; i < n; i++) {
        const polygon& a = results[i];

        bool keep = true;
        for (const int k : picked)
        {
            const polygon& b = results[k];

            // Different classes do not suppress each other unless agnostic.
            if (!agnostic && a.classIndex != b.classIndex)
                continue;

            if (PolygonIoU(a.boundingBoxPoints, b.boundingBoxPoints) >= iouThreshold)
            {
                // One sufficiently overlapping kept polygon is enough to
                // reject the candidate; no need to test the others.
                keep = false;
                break;
            }
        }

        if (keep)
        {
            picked.push_back(i);
        }
    }
}
C#代码
// One polygon detection as consumed by the Polygon_NMS code in this section.
public class DetectionResult
{
public Point2f[] BoundingBoxPoints{ get; set; } // the four vertices of the polygon
// Detection score; PolygonNonMaxSuppression sorts on it.
public double Confidence{ get; set; }
// Class index; PolygonNonMaxSuppression groups detections by this value.
public int ClassId{ get; set; }
}
//=====Polygon_IOU=====//
// Intersection-over-union of two convex quadrilaterals.
//
// box1, box2  Vertex arrays; each must be non-null and contain exactly four
//             points, otherwise 0 is returned.
//
// Returns intersection / (union + 1e-6); the small constant avoids a division
// by zero for degenerate polygons.
public static float PolygonIoU(Point2f[] box1, Point2f[] box2)
{
    if (box1 == null || box2 == null || box1.Length != 4 || box2.Length != 4)
        return 0;

    // Use the point-array overloads instead of building Mat objects: a Mat
    // wraps unmanaged memory and would have to be disposed on every call.
    // IntersectConvexConvex returns the area of the intersection polygon.
    float intersectionArea = Cv2.IntersectConvexConvex(box1, box2, out Point2f[] intersection);

    float area1 = (float)Cv2.ContourArea(box1);
    float area2 = (float)Cv2.ContourArea(box2);

    float unionArea = area1 + area2 - intersectionArea;
    return intersectionArea / (unionArea + 1e-6f);
}
//=====Polygon_NMS=====//
// Per-class greedy non-maximum suppression for polygon detections.
//
// results       All candidate detections; the array itself is not modified.
// iouThreshold  A polygon is removed when its IoU with a kept polygon of the
//               same class is greater than or equal to this value.
//
// Returns the kept detections, class by class, each class ordered by
// descending confidence.
public static YOLOV5PolygonClass.DetectionResult[] PolygonNonMaxSuppression(YOLOV5PolygonClass.DetectionResult[] results, double iouThreshold)
{
    // Group by class so that polygons of different classes never suppress each other.
    var groupedByClass = new Dictionary<int, List<YOLOV5PolygonClass.DetectionResult>>();
    foreach (var result in results)
    {
        List<YOLOV5PolygonClass.DetectionResult> sameClass;
        if (!groupedByClass.TryGetValue(result.ClassId, out sameClass))
        {
            sameClass = new List<YOLOV5PolygonClass.DetectionResult>();
            groupedByClass[result.ClassId] = sameClass;
        }
        sameClass.Add(result);
    }

    List<YOLOV5PolygonClass.DetectionResult> filteredResults = new List<YOLOV5PolygonClass.DetectionResult>();

    foreach (var classResults in groupedByClass.Values)
    {
        // Highest confidence first.
        classResults.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));

        while (classResults.Count > 0)
        {
            // The best remaining polygon is always kept ...
            var chosenResult = classResults[0];
            filteredResults.Add(chosenResult);
            classResults.RemoveAt(0);

            // ... and every remaining polygon that overlaps it too much is dropped.
            classResults.RemoveAll(otherResult =>
            {
                double iou = PolygonIoU(chosenResult.BoundingBoxPoints, otherResult.BoundingBoxPoints);
                return iou >= iouThreshold;
            });
        }
    }

    return filteredResults.ToArray();
}
四、Pose_NMS(关键点无需NMS)
仅对目标检测框进行NMS,关键点坐标无需NMS
C++代码
// One pose keypoint attached to a detection; the NMS code in this section never reads it.
struct KeyPoint
{
// 2D position of the keypoint. NOTE(review): coordinate frame (image pixels vs. model input) is not visible here - confirm.
cv::Point2f p;
// Keypoint score. NOTE(review): presumably a confidence or visibility value - confirm.
float prob;
};
// One pose detection candidate handled by the Pose_NMS routine in this section.
struct Object
{
// Axis-aligned bounding box with float coordinates; NMS is performed on this box only.
cv::Rect_<float> rect;
// Class index; nms_sorted_bboxes compares it so that, unless agnostic, only same-label boxes suppress each other.
int label;
// Detection score. NOTE(review): presumably the confidence the caller sorts by before NMS - confirm.
float prob;
// Keypoints belonging to this detection; they are carried along and not used by NMS.
std::vector<KeyPoint> keypoints;
};
//=====IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
cv::Rect_<float> inter = a.rect & b.rect;
return inter.area();
}
//=====Pose_NMS=====//
// Greedy non-maximum suppression over the bounding boxes of pose detections.
// Keypoints are not involved; only `rect` and `label` are read.
//
// objects        Candidate detections. As the function name indicates, they are
//                expected to be sorted already (presumably by descending score);
//                this is not checked here.
// picked         Output: cleared, then filled with the indices into `objects`
//                of the boxes that survive, in ascending index order.
// nms_threshold  A candidate is dropped when its IoU with an already kept box
//                is strictly greater than this value.
// agnostic       false (default): only boxes with the same label can suppress
//                each other. true: labels are ignored.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Every pair test needs both box areas, so compute them once up front.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int k : picked)
        {
            const Object& b = objects[k];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            // (two zero-area boxes give 0/0 = NaN, which compares false, so
            // such a candidate is never suppressed)
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[k] - inter_area;
            if (inter_area / union_area > nms_threshold)
            {
                // One sufficiently overlapping kept box is enough to reject
                // the candidate; the remaining kept boxes cannot change that.
                keep = false;
                break;
            }
        }

        if (keep)
            picked.push_back(i);
    }
}
C#代码(暂未提供)
五、Seg_NMS(掩码无需NMS)
仅对目标检测框进行NMS,掩码无需NMS
C++代码:
// One segmentation detection candidate handled by the Seg_NMS routine in this section.
struct Object
{
// Axis-aligned bounding box with float coordinates; NMS is performed on this box only.
cv::Rect_<float> rect;
// Class index; nms_sorted_bboxes compares it so that, unless agnostic, only same-label boxes suppress each other.
int label;
// Detection score. NOTE(review): presumably the confidence the caller sorts by before NMS - confirm.
float prob;
// NOTE(review): not used by the NMS code in this section; its meaning is not visible here - confirm with the decoder that fills it.
int gindex;
// Mask of the detection; carried along and not read by NMS.
cv::Mat mask;
};
//=====Seg_IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
cv::Rect_<float> inter = a.rect & b.rect;
return inter.area();
}
//=====Seg_NMS=====//
// Greedy non-maximum suppression over the bounding boxes of segmentation
// detections. Masks are not involved; only `rect` and `label` are read.
//
// objects        Candidate detections. As the function name indicates, they are
//                expected to be sorted already (presumably by descending score);
//                this is not checked here.
// picked         Output: cleared, then filled with the indices into `objects`
//                of the boxes that survive, in ascending index order.
// nms_threshold  A candidate is dropped when its IoU with an already kept box
//                is strictly greater than this value.
// agnostic       false (default): only boxes with the same label can suppress
//                each other. true: labels are ignored.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Every pair test needs both box areas, so compute them once up front.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int k : picked)
        {
            const Object& b = objects[k];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            // (two zero-area boxes give 0/0 = NaN, which compares false, so
            // such a candidate is never suppressed)
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[k] - inter_area;
            if (inter_area / union_area > nms_threshold)
            {
                // One sufficiently overlapping kept box is enough to reject
                // the candidate; the remaining kept boxes cannot change that.
                keep = false;
                break;
            }
        }

        if (keep)
            picked.push_back(i);
    }
}
六、Class_NMS(仅排序,无需NMS)
分类模型没有NMS
// One classification result. According to the surrounding text, classification
// output is only sorted and no NMS is applied to it.
struct Object
{
// Class index.
int label;
// Score of that class. NOTE(review): presumably the probability used as the sort key - confirm.
float prob;
};