深度学习后处理————手撕各种NMS代码

本文介绍Detect_NMS、OBB_NMS、Polygon_NMS、Pose_NMS、Seg_NMS、Class_NMS等等

后续会添加更多任务的NMS

  • Pose_NMS和Seg_NMS只需要对检测边框进行NMS,关键点和掩码无需NMS
  • Class_NMS无需NMS,仅排序

一、Detect_NMS

Detect_NMS是目标检测后处理。

  • C++代码如下:

// One axis-aligned detection candidate.
struct Object
{
    cv::Rect_<float> rect;  // bounding box (x, y, width, height) in float precision
    int label;              // class label; NMS only lets boxes with equal labels suppress each other unless agnostic
    float prob;             // confidence score; nms_sorted_bboxes expects candidates pre-sorted by this, descending
};

//=====Detect_IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
    cv::Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}

//=====Detect_NMS=====//
// Greedy non-maximum suppression over axis-aligned boxes.
//
// Precondition: `objects` must already be sorted by `prob` in descending order;
// this function does not sort. Walking the list in that order, each object is
// kept unless it overlaps an already-kept object (with the same label, unless
// `agnostic`) with IoU > nms_threshold.
//
// objects       - candidate detections, sorted by descending score.
// picked        - output; cleared, then filled with the indices (into `objects`)
//                 of the kept detections, in input order.
// nms_threshold - IoU strictly above which the lower-scored box is suppressed.
// agnostic      - when true, boxes with different labels may suppress each other.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Cache each box's area once instead of recomputing it per pair.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int idx : picked)
        {
            const Object& b = objects[idx];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[idx] - inter_area;
            // Two degenerate (zero-area) boxes would give 0/0; treat their IoU as 0.
            if (union_area > 0.f && inter_area / union_area > nms_threshold)
            {
                keep = false;
                break;  // one suppressing box is enough; skip the remaining comparisons
            }
        }

        if (keep)
            picked.push_back(i);
    }
}
  • C#代码如下:

// One axis-aligned detection result.
public class DetectionResult
{
    public int ClassId{ get; set; }       // numeric class id; ApplyNMS groups detections by this
    public string Class{ get; set; }      // class name (presumably display-only; not used by NMS)
    public Rect Rect{ get; set; }  // axis-aligned bounding box (x, y, width, height), not polygon vertices
    public float Confidence{ get; set; }  // confidence score used for filtering and ranking

}

//=====Detect_IOU=====//
/// <summary>
/// Intersection-over-union of two axis-aligned rectangles.
/// Returns 0 when the union area is not larger than float.Epsilon.
/// </summary>
private static float CalculateIoU(Rect rect1, Rect rect2)
{
    // Overlap extent along each axis, clamped to zero when the boxes are disjoint.
    int overlapWidth = Math.Max(0, Math.Min(rect1.X + rect1.Width, rect2.X + rect2.Width) - Math.Max(rect1.X, rect2.X));
    int overlapHeight = Math.Max(0, Math.Min(rect1.Y + rect1.Height, rect2.Y + rect2.Height) - Math.Max(rect1.Y, rect2.Y));
    int overlapArea = overlapWidth * overlapHeight;

    float unionArea = (float)(rect1.Width * rect1.Height) + (float)(rect2.Width * rect2.Height) - overlapArea;

    // Guard against division by (near) zero.
    if (unionArea <= float.Epsilon)
        return 0f;
    return overlapArea / unionArea;
}


//=====Detect_NMS=====//
/// <summary>
/// Per-class greedy non-maximum suppression over axis-aligned detections.
/// Detections below <paramref name="confidenceThreshold"/> are discarded first.
/// Within each class, the highest-confidence box is kept, and every remaining box
/// whose IoU with it exceeds <paramref name="iouThreshold"/> is removed; this repeats
/// until the class is exhausted. Boxes of different classes never suppress each other.
/// </summary>
public static List<DetectionResult> ApplyNMS(List<DetectionResult> detections, float iouThreshold = 0.5f, float confidenceThreshold = 0.5f)
{
    // 1. Drop low-confidence detections.
    var validDetections = detections.Where(d => d.Confidence >= confidenceThreshold).ToList();

    // 2. Group by class id.
    var classGroups = validDetections.GroupBy(d => d.ClassId).ToDictionary(g => g.Key, g => g.ToList());

    List<DetectionResult> finalResults = new List<DetectionResult>();

    // 3. Run greedy NMS independently for each class.
    foreach (var group in classGroups.Values)
    {
        // Highest confidence first.
        group.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));
        while (group.Count > 0)
        {
            // Keep the current best box...
            var current = group[0];
            finalResults.Add(current);
            group.RemoveAt(0);
            // ...and drop every remaining box that overlaps it too much.
            group.RemoveAll(other => CalculateIoU(current.Rect, other.Rect) > iouThreshold);
        }
    }
    return finalResults;
}

或者:

// Alternative: OpenCV's built-in NMS. No class ids are passed, so this is
// class-agnostic — boxes of different classes can suppress each other.
// `indices` receives the positions (in `results`) of the boxes that survive.
CvDnn.NMSBoxes(results.Select(x => x.Rect), results.Select(x => x.Confidence), scoreThreshold: 0.5f, nmsThreshold: 0.5f, out int[] indices);

二、OBB_NMS

OBB_NMS是旋转目标检测后处理

C++代码

// One rotated (oriented) detection candidate.
struct Object
{
    cv::RotatedRect rrect;  // oriented box: center, size and rotation angle
    int label;              // class label; NMS only compares equal labels unless agnostic
    float prob;             // confidence score; NMS expects candidates pre-sorted by this, descending
};

//=====OBB_IOU====//
static inline float intersection_area(const Object& a, const Object& b)
{
    std::vector<cv::Point2f> intersection;
    cv::rotatedRectangleIntersection(a.rrect, b.rrect, intersection);
    if (intersection.empty())
        return 0.f;

    return cv::contourArea(intersection);
}


//=====OBB_NMS=====//
// Greedy non-maximum suppression over rotated boxes.
//
// Precondition: `objects` must already be sorted by `prob` in descending order;
// this function does not sort. Walking the list in that order, each object is
// kept unless it overlaps an already-kept object (with the same label, unless
// `agnostic`) with rotated IoU > nms_threshold.
//
// objects       - candidate detections, sorted by descending score.
// picked        - output; cleared, then filled with the indices (into `objects`)
//                 of the kept detections, in input order.
// nms_threshold - IoU strictly above which the lower-scored box is suppressed.
// agnostic      - when true, boxes with different labels may suppress each other.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Cache each rotated box's area once instead of recomputing it per pair.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rrect.size.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int idx : picked)
        {
            const Object& b = objects[idx];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[idx] - inter_area;
            // Two degenerate (zero-area) boxes would give 0/0; treat their IoU as 0.
            if (union_area > 0.f && inter_area / union_area > nms_threshold)
            {
                keep = false;
                break;  // one suppressing box is enough; skip the remaining comparisons
            }
        }

        if (keep)
            picked.push_back(i);
    }
}

C#代码

// One oriented (rotated-box) detection result.
public class OBBResult
{
    public int ClassId{ get; set; }       // numeric class id; ApplyNMS groups detections by this
    public string Class{ get; set; }      // class name (presumably display-only; not used by NMS)
    public RotatedRect RotaRect{ get; set; }  // rotated box: (new Point2f(centerX, centerY), new Size2f(width, height), rotation angle)
    public float Confidence{ get; set; }  // confidence score used for filtering and ranking
}

//=====RotatedIOU=======//
/// <summary>
/// Intersection-over-union of two rotated rectangles.
/// If OpenCV's convex-intersection routine throws, the overlap is treated as 0.
/// Returns 0 when the union area is not positive.
/// </summary>
private static float CalculateRotatedIoU(RotatedRect rect1, RotatedRect rect2)
{
    // Corner vertices of both boxes. These are computed outside the try block, so
    // failures here still propagate to the caller.
    Point2f[] corners1 = rect1.Points();
    Point2f[] corners2 = rect2.Points();

    float overlapArea;
    try
    {
        overlapArea = Cv2.IntersectConvexConvex(corners1, corners2, out Point2f[] _);
    }
    catch
    {
        // Best effort: any failure of the intersection routine counts as "no overlap".
        overlapArea = 0f;
    }

    float unionArea = rect1.Size.Width * rect1.Size.Height
                    + rect2.Size.Width * rect2.Size.Height
                    - overlapArea;

    if (!(unionArea > 0))
        return 0f;
    return overlapArea / unionArea;
}


//=====OBB_NMS=======//
/// <summary>
/// Per-class greedy non-maximum suppression over rotated detections.
/// Detections below <paramref name="confidenceThreshold"/> are discarded first.
/// Within each class, the highest-confidence box is kept, and every remaining box
/// whose rotated IoU with it exceeds <paramref name="iouThreshold"/> is removed.
/// Boxes of different classes never suppress each other.
/// </summary>
public static List<OBBResult> ApplyNMS(List<OBBResult> detections, float iouThreshold = 0.5f, float confidenceThreshold = 0.5f)
{
    // 1. Drop low-confidence detections.
    var validDetections = detections.Where(d => d.Confidence >= confidenceThreshold).ToList();
    // 2. Group by class id.
    var classGroups = validDetections
        .GroupBy(d => d.ClassId)
        .ToDictionary(g => g.Key, g => g.ToList());
    List<OBBResult> finalResults = new List<OBBResult>();
    // 3. Run greedy NMS independently for each class.
    foreach (var group in classGroups.Values)
    {
        // Highest confidence first.
        group.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));
        // Survivors of the current class.
        List<OBBResult> classResults = new List<OBBResult>();
        while (group.Count > 0)
        {
            // Keep the current best box...
            var current = group[0];
            classResults.Add(current);
            group.RemoveAt(0);

            // ...and drop every remaining box that overlaps it too much
            // (iterate backwards so removals do not shift unvisited indices).
            for (int i = group.Count - 1; i >= 0; i--)
            {
                float iou = CalculateRotatedIoU(current.RotaRect, group[i].RotaRect);
                if (iou > iouThreshold)
                {
                    group.RemoveAt(i);
                }
            }
        }

        finalResults.AddRange(classResults);
    }

    return finalResults;
}

三、Polygon_NMS

多边形检测NMS

C++代码

// One polygon detection candidate.
struct polygon
{
    std::vector<cv::Point2f> boundingBoxPoints; // polygon vertices; PolygonIoU only accepts exactly 4
    float confidence;   // confidence score
    int classIndex;     // class index; PolygonNonMaxSuppression only compares polygons of the same class
};


//=====Polygon_IOU=====//
// Intersection-over-union of two convex quadrilaterals.
//
// Both inputs must contain exactly 4 vertices (otherwise 0 is returned). They must
// describe convex polygons, as cv::intersectConvexConvex requires. The result is
// intersection / (union + 1e-6), so it is never a division by zero.
float PolygonIoU(const std::vector<cv::Point2f>& PolyBox1, const std::vector<cv::Point2f>& PolyBox2)
{
    if (PolyBox1.size() != 4 || PolyBox2.size() != 4) {
        return 0.0f;
    }

    // Work on the float vertices directly. Converting them to integer cv::Point
    // would truncate sub-pixel coordinates and distort the IoU of small polygons.
    std::vector<cv::Point2f> intersection;
    cv::intersectConvexConvex(PolyBox1, PolyBox2, intersection);  // empty when disjoint

    const float intersectionArea = intersection.empty() ? 0.0f : static_cast<float>(cv::contourArea(intersection));

    const float area1 = static_cast<float>(cv::contourArea(PolyBox1));
    const float area2 = static_cast<float>(cv::contourArea(PolyBox2));

    const float unionArea = area1 + area2 - intersectionArea;
    return intersectionArea / (unionArea + 1e-6f);  // epsilon guards against division by zero
}


//=====Polygon_NMS======//
// Greedy non-maximum suppression for polygon detections.
//
// Precondition: `results` must already be sorted by `confidence` in descending
// order; this function does not sort.
//
// results        - candidate polygons, sorted by descending confidence.
// picked         - output; cleared, then filled with the indices (into `results`)
//                  of the kept polygons, in input order.
// iouThreshold   - a polygon is dropped when its IoU with an already-kept polygon
//                  is >= this value (inclusive comparison).
// max_candidates - only the first `max_candidates` entries are considered
//                  (default 300). Pass a value <= 0 to consider all entries.
// agnostic       - when true, polygons of different classes may suppress each
//                  other; by default only same-class polygons are compared.
static void PolygonNonMaxSuppression(const std::vector<polygon>& results, std::vector<int>& picked, float iouThreshold, int max_candidates = 300, bool agnostic = false)
{
    picked.clear();

    int n = static_cast<int>(results.size());
    if (max_candidates > 0 && n > max_candidates) { n = max_candidates; }

    for (int i = 0; i < n; i++) {
        const polygon& a = results[i];

        bool keep = true;
        for (const int idx : picked)
        {
            const polygon& b = results[idx];
            // Skip polygons of other classes unless running class-agnostic.
            if (!agnostic && a.classIndex != b.classIndex)
                continue;
            if (PolygonIoU(a.boundingBoxPoints, b.boundingBoxPoints) >= iouThreshold)
            {
                keep = false;
                break;  // one suppressing polygon is enough
            }
        }

        if (keep)
            picked.push_back(i);
    }
}

C#代码

// One polygon detection result.
public class DetectionResult
{
    public Point2f[] BoundingBoxPoints{ get; set; }  // the polygon's four vertices (PolygonIoU requires exactly 4)
    public double Confidence{ get; set; }            // confidence score used for ranking in NMS
    public int ClassId{ get; set; }                  // numeric class id; NMS groups detections by this
}


//=====Polygon_IOU=====//
/// <summary>
/// Intersection-over-union of two convex quadrilaterals given as 4 vertices each.
/// Returns 0 when either argument is null or does not have exactly 4 vertices.
/// The result is intersection / (union + 1e-6), so it is never a division by zero.
/// </summary>
public static float PolygonIoU(Point2f[] box1, Point2f[] box2)
{
    if (box1 == null || box2 == null || box1.Length != 4 || box2.Length != 4)
        return 0;

    // Use the Point2f[] overloads directly. Building Mat objects here allocated
    // native memory that was never disposed, leaking on every call.
    Point2f[] intersection;
    Cv2.IntersectConvexConvex(box1, box2, out intersection); // intersection polygon of the two convex quads

    // An empty intersection polygon means the quads do not overlap.
    float intersectionArea = (intersection != null && intersection.Length > 0)
        ? (float)Cv2.ContourArea(intersection)
        : 0.0f;

    float area1 = (float)Cv2.ContourArea(box1);
    float area2 = (float)Cv2.ContourArea(box2);

    float unionArea = area1 + area2 - intersectionArea;
    return intersectionArea / (unionArea + 1e-6f);  // epsilon guards against division by zero
}


//=====Polygon_NMS=====//
/// <summary>
/// Per-class greedy non-maximum suppression for polygon detections.
/// Within each class, the highest-confidence polygon is kept, and every remaining
/// polygon whose IoU with it is &gt;= <paramref name="iouThreshold"/> is removed;
/// this repeats until the class is exhausted.
/// </summary>
public static YOLOV5PolygonClass.DetectionResult[] PolygonNonMaxSuppression(YOLOV5PolygonClass.DetectionResult[] results, double iouThreshold)
{
    // Group by class id so polygons of different classes never suppress each other.
    var groupedByClass = new Dictionary<int, List<YOLOV5PolygonClass.DetectionResult>>();
    foreach (var result in results)
    {
        if (!groupedByClass.TryGetValue(result.ClassId, out var bucket))
        {
            bucket = new List<YOLOV5PolygonClass.DetectionResult>();
            groupedByClass[result.ClassId] = bucket;
        }
        bucket.Add(result);
    }

    List<YOLOV5PolygonClass.DetectionResult> filteredResults = new List<YOLOV5PolygonClass.DetectionResult>();

    // Greedy NMS per class.
    foreach (var kvp in groupedByClass)
    {
        var classResults = kvp.Value;

        // Highest confidence first.
        classResults.Sort((a, b) => b.Confidence.CompareTo(a.Confidence));

        while (classResults.Count > 0)
        {
            // Keep the current best polygon...
            var chosenResult = classResults[0];
            filteredResults.Add(chosenResult);
            classResults.RemoveAt(0);

            // ...and drop every remaining polygon that overlaps it at or above the threshold.
            classResults.RemoveAll(otherResult =>
            {
                double iou = PolygonIoU(chosenResult.BoundingBoxPoints, otherResult.BoundingBoxPoints);
                return iou >= iouThreshold;
            });
        }
    }

    return filteredResults.ToArray();
}

四、Pose_NMS(关键点无需NMS)

仅对目标检测框进行NMS,关键点坐标无需NMS

C++代码

// One predicted keypoint of a pose detection.
struct KeyPoint
{
    cv::Point2f p;  // keypoint position
    float prob;     // keypoint score (presumably confidence/visibility — confirm against the decoder)
};

// One pose detection: a bounding box plus its keypoints.
// Only `rect` (and `label`) take part in NMS; the keypoints are not used by it.
struct Object
{
    cv::Rect_<float> rect;            // bounding box used for NMS
    int label;                        // class label; NMS only compares equal labels unless agnostic
    float prob;                       // box confidence score; NMS expects candidates pre-sorted by this, descending
    std::vector<KeyPoint> keypoints;  // keypoints belonging to this box
};

//=====IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
    cv::Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}

//=====Pose_NMS=====//
// Greedy non-maximum suppression for pose detections (boxes only; keypoints are ignored).
//
// Precondition: `objects` must already be sorted by `prob` in descending order;
// this function does not sort. Walking the list in that order, each object is
// kept unless its box overlaps an already-kept box (with the same label, unless
// `agnostic`) with IoU > nms_threshold.
//
// objects       - candidate detections, sorted by descending score.
// picked        - output; cleared, then filled with the indices (into `objects`)
//                 of the kept detections, in input order.
// nms_threshold - IoU strictly above which the lower-scored box is suppressed.
// agnostic      - when true, boxes with different labels may suppress each other.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Cache each box's area once instead of recomputing it per pair.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int idx : picked)
        {
            const Object& b = objects[idx];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[idx] - inter_area;
            // Two degenerate (zero-area) boxes would give 0/0; treat their IoU as 0.
            if (union_area > 0.f && inter_area / union_area > nms_threshold)
            {
                keep = false;
                break;  // one suppressing box is enough; skip the remaining comparisons
            }
        }

        if (keep)
            picked.push_back(i);
    }
}

C#代码(暂缺;Pose_NMS只对检测框做NMS,可直接复用上文Detect_NMS的C#实现)

五、Seg_NMS(掩码无需NMS)

仅对目标检测框进行NMS,掩码无需NMS

C++代码:

// One instance-segmentation detection: a bounding box plus its mask.
// Only `rect` (and `label`) take part in NMS; the mask is not used by it.
struct Object
{
    cv::Rect_<float> rect;  // bounding box used for NMS
    int label;              // class label; NMS only compares equal labels unless agnostic
    float prob;             // box confidence score; NMS expects candidates pre-sorted by this, descending
    int gindex;             // NOTE(review): presumably the index of this candidate in the raw model output (e.g. to fetch mask coefficients) — confirm
    cv::Mat mask;           // instance mask for this detection
};

//=====Seg_IOU=====//
static inline float intersection_area(const Object& a, const Object& b)
{
    cv::Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}

//=====Seg_NMS=====//
// Greedy non-maximum suppression for segmentation detections (boxes only; masks are ignored).
//
// Precondition: `objects` must already be sorted by `prob` in descending order;
// this function does not sort. Walking the list in that order, each object is
// kept unless its box overlaps an already-kept box (with the same label, unless
// `agnostic`) with IoU > nms_threshold.
//
// objects       - candidate detections, sorted by descending score.
// picked        - output; cleared, then filled with the indices (into `objects`)
//                 of the kept detections, in input order.
// nms_threshold - IoU strictly above which the lower-scored box is suppressed.
// agnostic      - when true, boxes with different labels may suppress each other.
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = static_cast<int>(objects.size());

    // Cache each box's area once instead of recomputing it per pair.
    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = objects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = objects[i];

        bool keep = true;
        for (const int idx : picked)
        {
            const Object& b = objects[idx];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            const float inter_area = intersection_area(a, b);
            const float union_area = areas[i] + areas[idx] - inter_area;
            // Two degenerate (zero-area) boxes would give 0/0; treat their IoU as 0.
            if (union_area > 0.f && inter_area / union_area > nms_threshold)
            {
                keep = false;
                break;  // one suppressing box is enough; skip the remaining comparisons
            }
        }

        if (keep)
            picked.push_back(i);
    }
}

六、Class_NMS(仅排序,无需NMS)

分类模型没有NMS

// One classification result. Classification needs no NMS; results are only ranked by `prob`.
struct Object
{
    int label;   // class label
    float prob;  // class score used for ranking
};

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值