关于目标检测样本集中负样本的引入方法

本文介绍了一种引入负样本以降低误检率的方法:先找出并裁剪误检目标,再把这些切片粘贴到原始样本集的影像中。代码中,shp2imagexy.py 负责把 shapefile 中矢量标注框的地理坐标转换为影像像素坐标,aug.py 则通过复制-粘贴(copy-paste)式的数据增强把负样本切片粘贴到影像上。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

在目标检测推理过程中,经常会出现误检。把这些误检目标作为负样本(即不为其添加标注的背景)引入训练样本集,可以有效降低误检率。
在这里插入图片描述

思路

  1. 找出误检的目标,并将其裁剪成切片(负样本);
  2. 将负样本切片随机粘贴到原始样本集的影像中,且不为其添加标注。

原样本切片
在这里插入图片描述
引入负样本后的切片
在这里插入图片描述

代码

small_objects_paste.py

import aug as am
import glob
import random
import os
from shp2imagexy import shp2imagexy


def mkdir(path):
    """Create directory *path*, and any missing parents, if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces an exists-then-mkdir check.
    It also creates intermediate directories, and it cannot fail when another
    process creates *path* between the check and the call. If *path* exists
    as a regular file, FileExistsError is raised instead of silently
    continuing.
    """
    os.makedirs(path, exist_ok=True)


def generate_img(image_dir,
                 crop_dir,
                 out_dir,
                 small_num=2):
    """Paste randomly chosen negative-sample crops into every sample image.

    Every ``*.tif`` one folder below *image_dir* is processed together with
    the shapefile (``*.shp``) in the same folder. For each image, *small_num*
    crops from *crop_dir* are sampled and handed to ``aug.copysmallobjects2``,
    which writes the augmented image under *out_dir*.

    :param image_dir: root folder whose sub-folders hold ``*.tif`` + ``*.shp``
    :param crop_dir: folder with the negative-sample crops (``*.tif``)
    :param out_dir: output folder for the augmented images (created if missing)
    :param small_num: crops pasted per image; clamped to the number available
    """
    mkdir(out_dir)
    imglist = glob.glob(f'./{image_dir}/*/*.tif')
    small_imglist = glob.glob(f'./{crop_dir}/*.tif')
    print(f'the length of img is {len(imglist)}')
    print(f'the length of small img is {len(small_imglist)}')
    random.shuffle(imglist)
    random.shuffle(small_imglist)
    # random.sample raises ValueError when asked for more items than exist.
    sample_num = min(small_num, len(small_imglist))
    for imgPath in imglist:
        subRoot = os.path.split(imgPath)[0]
        shpPaths = glob.glob(f'{subRoot}/*.shp')
        if not shpPaths:
            # Without a shapefile there are no boxes to read; skip this image
            # instead of aborting the whole run with an IndexError.
            print(f'no shapefile found next to {imgPath}, skipped')
            continue
        small_img = random.sample(small_imglist, sample_num)
        labels, boxes = shp2imagexy(imgPath, shpPaths[0])
        am.copysmallobjects2(imgPath, labels, boxes, out_dir, small_img, class_dict={'1': 1})


if __name__ == '__main__':
    from shutil import copyfile

    image_dir = 'playground'
    crop_dir = 'playground_neg/images'
    out_dir = 'playground_v2'
    # Step 1: paste negative crops into every sample image (written to out_dir).
    generate_img(image_dir, crop_dir, out_dir)

    # Step 2: for every "*V1" sample folder, create a "*V2" sibling that holds
    # the augmented image plus copies of the remaining files (shapefile parts,
    # ...), all renamed to "<base>_v2.<ext>".
    for entry in os.listdir(image_dir):
        src_dir = os.path.join(image_dir, entry)
        # Skip plain files and non-V1 folders. A V2 folder from a previous
        # run would map onto itself, and copyfile raises SameFileError.
        if not os.path.isdir(src_dir) or 'V1' not in entry:
            continue
        dst_dir = os.path.join(image_dir, entry.replace('V1', 'V2'))
        mkdir(dst_dir)
        for name in os.listdir(src_dir):
            # Split at the FIRST dot so compound extensions (".shp.xml",
            # ".tif.aux.xml") are kept whole. Otherwise they all collapse to
            # "<base>_v2.xml" and overwrite each other.
            stem, _, suffix = name.partition('.')
            base_name = stem.split('_')[0]
            dst_name = f'{base_name}_v2.{suffix}' if suffix else f'{base_name}_v2'
            src_path = os.path.join(src_dir, name)
            if suffix == 'tif':
                # copysmallobjects2 saves the augmented image under the same
                # file name directly in out_dir (for paths without a
                # 'train'/'val' part). It writes nothing for images without
                # labels, so fall back to the original image in that case.
                augmented = os.path.join(out_dir, name)
                if os.path.exists(augmented):
                    src_path = augmented
                else:
                    print(f'no augmented image for {src_path}, copying the original')
            copyfile(src_path, os.path.join(dst_dir, dst_name))

shp2imagexy.py

# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from osgeo import osr
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import math


def getSRSPair(dataset):
    '''
    Return the projected and the geographic spatial reference of a dataset.

    The geographic SRS is the geographic CRS underlying the dataset's
    projection (``CloneGeogCS``).

    On GDAL >= 3 both SRS objects are switched to the traditional GIS axis
    order (x / easting / longitude first). Without this switch, GDAL 3
    follows the authority axis order. For EPSG:4326-style geographic CRSs,
    ``TransformPoint`` would then take and return (lat, lon) instead of
    (lon, lat).

    :param dataset: GDAL dataset (e.g. from gdal.Open)
    :return: (projected SRS, geographic SRS)
    '''
    prosrs = osr.SpatialReference()
    prosrs.ImportFromWkt(dataset.GetProjection())
    geosrs = prosrs.CloneGeogCS()
    # OAMS_TRADITIONAL_GIS_ORDER only exists in GDAL >= 3; GDAL 2 already
    # uses the traditional order, so nothing needs to be done there.
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        prosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        geosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    return prosrs, geosrs


def xy_to_coor(x, y):
    '''
    Convert plane coordinates back to (lon, lat) in degrees.

    The constants match the inverse Miller cylindrical projection. The plane
    is W = 2*pi*R wide and H = W/2 high, with x = 0 at lon -180 (x grows
    east) and y = H/2 on the equator (smaller y is further north).

    :param x: plane x coordinate
    :param y: plane y coordinate
    :return: (lon, lat) in degrees, each rounded to 7 decimal places
    '''
    # NOTE(review): 6381372 is used as the Earth radius. It differs from the
    # usual 6371000 / 6378137; keep it identical to whatever forward
    # projection produced (x, y) before changing it.
    L = 6381372 * math.pi * 2   # plane width: 2 * pi * R
    W = L
    H = L / 2                   # plane height
    mill = 2.3                  # Miller parameter
    lat = ((H / 2 - y) * 2 * mill) / (1.25 * H)
    lat = ((math.atan(math.exp(lat)) - 0.25 * math.pi) * 180) / (0.4 * math.pi)
    lon = (x - W / 2) * 360 / W
    # TODO: confirm how many decimal places lon/lat should finally keep.
    return round(lon, 7), round(lat, 7)


def geo2lonlat(dataset, x, y):
    '''
    Transform a point from the dataset's projected CRS to its geographic CRS.

    :param dataset: GDAL dataset whose projection defines the source CRS
    :param x: projected x coordinate
    :param y: projected y coordinate
    :return: the first two components of TransformPoint. Their axis order
        follows the SRS objects from getSRSPair; this is (lon, lat) under the
        traditional GIS axis order.
    '''
    source_srs, target_srs = getSRSPair(dataset)
    transform = osr.CoordinateTransformation(source_srs, target_srs)
    return transform.TransformPoint(x, y)[:2]


def lonlat2geo(dataset, lon, lat):
    '''
    Transform a geographic point into the dataset's projected CRS.

    :param dataset: GDAL dataset whose projection defines the target CRS
    :param lon: longitude
    :param lat: latitude
    :return: the first two components of TransformPoint, i.e. the projected
        (x, y). The expected input axis order follows the SRS objects from
        getSRSPair.
    '''
    projected, geographic = getSRSPair(dataset)
    return osr.CoordinateTransformation(geographic, projected).TransformPoint(lon, lat)[:2]


def imagexy2geo(dataset, col, row):
    '''
    Convert an image pixel position (col, row) to georeferenced coordinates.

    Applies the dataset's six-parameter GDAL geotransform
    ``gt = (x0, dx_col, dx_row, y0, dy_col, dy_row)``:

        x = gt[0] + col * gt[1] + row * gt[2]
        y = gt[3] + col * gt[4] + row * gt[5]

    The result is in the dataset's own CRS (projected or geographic).
    No debug output is printed, so the function can be called per point.

    :param dataset: GDAL dataset (anything providing GetGeoTransform())
    :param col: pixel column (x direction)
    :param row: pixel row (y direction)
    :return: (x, y) georeferenced coordinates
    '''
    trans = dataset.GetGeoTransform()
    px = trans[0] + col * trans[1] + row * trans[2]
    py = trans[3] + col * trans[4] + row * trans[5]
    return px, py


def geo2imagexy01(dataset, x, y):
    '''
    Convert georeferenced coordinates to image pixel coordinates.

    Solves the 2x2 linear system of the GDAL six-parameter geotransform
    with numpy:

        [gt[1] gt[2]] [col]   [x - gt[0]]
        [gt[4] gt[5]] [row] = [y - gt[3]]

    :param dataset: GDAL dataset (anything providing GetGeoTransform())
    :param x: x coordinate in the dataset's CRS
    :param y: y coordinate in the dataset's CRS
    :return: numpy array [col, row] (pixel column first, then row)
    '''
    gt = dataset.GetGeoTransform()
    coeffs = np.array([[gt[1], gt[2]],
                       [gt[4], gt[5]]])
    offsets = np.array([x - gt[0], y - gt[3]])
    return np.linalg.solve(coeffs, offsets)

def geo2imagexy(dataset, x, y):
    '''
    Convert georeferenced coordinates to image pixel coordinates.

    Inverts the GDAL six-parameter geotransform in closed form (Cramer's
    rule on the 2x2 system) instead of calling numpy.

    :param dataset: GDAL dataset (anything providing GetGeoTransform())
    :param x: x coordinate in the dataset's CRS
    :param y: y coordinate in the dataset's CRS
    :return: list [col, row] (pixel column first, then row)
    '''
    gt = dataset.GetGeoTransform()
    dx = x - gt[0]
    dy = y - gt[3]
    det = gt[1] * gt[5] - gt[2] * gt[4]
    col = (gt[5] * dx - gt[2] * dy) / det
    row = (gt[1] * dy - gt[4] * dx) / det
    return [col, row]


def shp2imagexy(imgPath, shpPath):
    '''
    Read a shapefile and convert each feature's envelope to a pixel box.

    Each envelope (minX, maxX, minY, maxY) is mapped through the image
    geotransform. (minX, maxY) gives the first corner and (maxX, minY) the
    second, i.e. top-left / bottom-right for a north-up image. Features
    without a geometry are skipped.

    :param imgPath: path of the georeferenced image (opened with GDAL)
    :param shpPath: path of the shapefile; every feature needs a 'cls' field
    :return: (class_list, boxes_list). class_list holds 1 for every feature.
        boxes_list holds [x1, y1, x2, y2, cls] with float pixel coordinates
        and the feature's 'cls' value.
    :raises IOError: if the image or the shapefile cannot be opened
    '''
    dataset = gdal.Open(imgPath)
    if dataset is None:
        raise IOError(f'Could not open image {imgPath}')
    # Read-only access: nothing is written back to the shapefile.
    ds = ogr.Open(shpPath, 0)
    if ds is None:
        raise IOError(f'Could not open shapefile {shpPath}')
    in_lyr = ds.GetLayer()

    class_list = []
    boxes_list = []
    for feature in in_lyr:
        geom = feature.GetGeometryRef()
        if geom is None:
            continue
        cls_id = feature.GetField('cls')
        minx, maxx, miny, maxy = geom.GetEnvelope()
        coordsMin = geo2imagexy(dataset, minx, maxy)
        coordsMax = geo2imagexy(dataset, maxx, miny)
        boxes_list.append([coordsMin[0], coordsMin[1], coordsMax[0], coordsMax[1], cls_id])
        # Every feature is labelled as class 1, whatever its 'cls' field says.
        class_list.append(1)

    return class_list, boxes_list


if __name__ == '__main__':
    # Visual check: draw the boxes read from the shapefile onto the image.
    img_filename = 'E:/2_1/tc/0000000001_V1/0000000001.tif'
    dst_filename = 'E:/2_1/tc/0000000001_V1/0000000001_V1_POLY.shp'
    # shp2imagexy returns (class_list, boxes_list); only the boxes are drawn.
    _, boxes = shp2imagexy(img_filename, dst_filename)
    img = cv.imread(img_filename, cv.IMREAD_LOAD_GDAL)
    if img is None:
        raise SystemExit(f'could not read {img_filename}')
    for bbox in boxes:
        print(bbox)
        cv.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 100, 255), 5)

    plt.imshow(img)
    plt.show()

aug.py

from lxml import etree
from os.path import basename, split, join, dirname
from util import *
import xml.etree.ElementTree as ET
import gdalTools

def find_str(filename):
    """Return the directory part of *filename* starting at 'train' or 'val'.

    For ``'data/train/images/a.jpg'`` this is ``'train/images'``. 'train'
    takes precedence over 'val'. When neither substring occurs the result
    is ``''``, so callers save directly into their base directory.
    """
    for marker in ('train', 'val'):
        pos = filename.find(marker)
        if pos != -1:
            return dirname(filename[pos:])
    # Handle "no match" explicitly. Slicing from find()'s -1 keeps only the
    # last character; that gives '' by accident, but '/' for a trailing slash.
    return ''


def convert_all_boxes(shape, anno_infos, yolo_label_txt_dir):
    """Write boxes to a YOLO-style label text file, one box per line.

    :param shape: image shape (height, width, channels)
    :param anno_infos: iterable of (target_id, x1, y1, x2, y2) pixel boxes
    :param yolo_label_txt_dir: path of the label file to (over)write
    """
    height, width, _ = shape
    # The context manager guarantees the file is flushed and closed.
    with open(yolo_label_txt_dir, 'w') as label_file:
        for target_id, x1, y1, x2, y2 in anno_infos:
            # convert (imported from elsewhere, presumably util) takes the box
            # as (x1, x2, y1, y2); note the order.
            bb = convert((width, height), (float(x1), float(x2), float(y1), float(y2)))
            label_file.write(str(target_id) + " " + " ".join(str(a) for a in bb) + '\n')


def save_crop_image(save_crop_base_dir, image_dir, idx, roi):
    """Save the crop *roi* as ``<image stem>_crop_<idx>.jpg``.

    The file goes into *save_crop_base_dir* joined with the 'train'/'val'
    sub-path of *image_dir*, as returned by find_str.
    """
    from os.path import splitext

    crop_save_dir = join(save_crop_base_dir, find_str(image_dir))
    check_dir(crop_save_dir)
    # splitext drops the whole extension including the dot. Slicing off the
    # last 3 characters left 'a.' for 'a.jpg' and produced 'a._crop_0.jpg'.
    stem = splitext(basename(image_dir))[0]
    crop_img_save_dir = join(crop_save_dir, stem + '_crop_' + str(idx) + '.jpg')
    cv2.imwrite(crop_img_save_dir, roi)


def copysmallobjects(image_dir, label_dir, save_base_dir, save_crop_base_dir=None,
                     save_annoation_base_dir=None):
    """Copy-paste augmentation of small class-'1' objects from YOLO labels.

    Each labelled object of class '1' that passes
    ``issmallobject(..., thresh=64 * 64)`` is pasted at the positions
    returned by ``random_add_patches(..., paste_number=2, iou_thresh=0.2)``.
    The second and later pastes use a flipped copy (flip_bbox). The
    augmented image and its label file are written to
    ``save_base_dir/<train|val sub-path>/<stem>_augment.<ext|txt>``.

    :param image_dir: path of the source image
    :param label_dir: path of its YOLO label txt
    :param save_base_dir: output root directory
    :param save_crop_base_dir: unused
    :param save_annoation_base_dir: unused
    """
    from os.path import splitext

    image = cv2.imread(image_dir)

    labels = read_label_txt(label_dir)
    if len(labels) == 0:
        return
    rescale_labels = rescale_yolo_labels(labels, image.shape)  # YOLO -> pixel boxes
    all_boxes = []

    for rescale_label in rescale_labels:
        all_boxes.append(rescale_label)
        # Size of the object box.
        rescale_label_height = rescale_label[4] - rescale_label[2]
        rescale_label_width = rescale_label[3] - rescale_label[1]

        if (issmallobject((rescale_label_height, rescale_label_width), thresh=64 * 64)
                and rescale_label[0] == '1'):
            roi = image[rescale_label[2]:rescale_label[4], rescale_label[1]:rescale_label[3]]

            new_bboxes = random_add_patches(rescale_label, rescale_labels, image.shape,
                                            paste_number=2, iou_thresh=0.2)

            # Paste the object at each new position. A box is recorded only
            # after its paste succeeded, so no label points at a spot where
            # nothing was pasted.
            for count, new_bbox in enumerate(new_bboxes, start=1):
                _, bbox_left, bbox_top, bbox_right, bbox_bottom = new_bbox[:5]
                try:
                    if count > 1:
                        roi = flip_bbox(roi)
                    image[bbox_top:bbox_bottom, bbox_left:bbox_right] = roi
                except ValueError:
                    # roi and the target region have incompatible shapes.
                    continue
                all_boxes.append(new_bbox)

    dir_name = find_str(image_dir)
    save_dir = join(save_base_dir, dir_name)
    check_dir(save_dir)
    # Derive both output names from the stem. With str.replace('.jpg', ...)
    # a non-.jpg input kept its name for both files, so the label file
    # overwrote the saved image.
    stem, ext = splitext(basename(image_dir))
    yolo_txt_dir = join(save_dir, stem + '_augment.txt')
    cv2.imwrite(join(save_dir, stem + '_augment' + ext), image)
    convert_all_boxes(image.shape, all_boxes, yolo_txt_dir)


def GaussianBlurImg(image):
    """Randomly apply a Gaussian blur.

    Draws k = random.randint(0, 9). When k is odd, the image is blurred with
    a k x k kernel (sigma derived from the kernel size); otherwise it is
    returned unchanged.
    """
    kernel = random.randint(0, 9)
    if kernel % 2 == 0:
        return image
    return cv2.GaussianBlur(image, ksize=(kernel, kernel), sigmaX=0, sigmaY=0)


def suo_fang(image, area_max=1100, area_min=700):
    """Rescale *image* until its pixel area (height * width) is within bounds.

    ("suo fang" is pinyin for "rescale".) While the area exceeds *area_max*,
    both sides shrink by ~5 % per step. Then, while the area is below
    *area_min*, both sides grow by ~20 % per step.

    :param image: image array of shape (H, W) or (H, W, C)
    :param area_max: upper bound on height * width
    :param area_min: lower bound on height * width
    :return: the rescaled image (the input itself if already in range)
    """
    height, width = image.shape[:2]

    # The loop conditions use the real size after every resize. Each step
    # changes the size by at least one pixel, so neither loop can stall. For
    # example int(4 * 1.2) == 4, so without the +1 tiny images never grow.
    # The max(1, ...) keeps the size from reaching 0, which cv2.resize rejects.
    while height * width > area_max and (height > 1 or width > 1):
        image = cv2.resize(image, (max(1, int(width * 0.95)), max(1, int(height * 0.95))))
        height, width = image.shape[:2]

    while height * width < area_min:
        image = cv2.resize(image, (max(width + 1, int(width * 1.2)),
                                   max(height + 1, int(height * 1.2))))
        height, width = image.shape[:2]

    return image


def parse_xml(path, class_dict):
    """Read the object boxes of a Pascal-VOC style XML annotation.

    :param path: path of the XML file
    :param class_dict: maps each object's <name> text to a class id
    :return: (classes, boxes). classes is a numpy array of class ids. boxes
        is an int32 array of shape (N, 4) with [xmin, ymin, xmax, ymax];
        fractional coordinates are truncated.
    :raises KeyError: if an object's name is missing from class_dict
    """
    tree = ET.parse(path)
    class_list = []
    boxes_list = []
    for obj in tree.findall('object'):
        bndbox = obj.find('bndbox')
        xmin = float(bndbox.find('xmin').text)
        xmax = float(bndbox.find('xmax').text)
        ymin = float(bndbox.find('ymin').text)
        ymax = float(bndbox.find('ymax').text)
        boxes_list.append([xmin, ymin, xmax, ymax])
        class_list.append(class_dict[obj.find('name').text])
    # <difficult> is not read: it was never returned, and requiring it raised
    # AttributeError for objects that omit the element.
    # reshape keeps the (N, 4) shape even when the file has no objects.
    return np.array(class_list), np.array(boxes_list).reshape(-1, 4).astype(np.int32)


def copysmallobjects2(image_dir, labels, boxes, save_base_dir, small_img_dir, class_dict):
    """Paste negative-sample crops into a georeferenced image and save it.

    The image is read with gdalTools.read_img. Each crop in *small_img_dir*
    is blended in with cv2.seamlessClone at the position returned by
    random_add_patches2. The result is written with the original projection
    and geotransform to ``save_base_dir/<train|val sub-path>/<file name>``.
    No annotation file is written for the pasted crops. Nothing is written
    when *labels* is empty.

    :param image_dir: path of the source image
    :param labels: class label of every existing box
    :param boxes: existing boxes [x1, y1, x2, y2, ...] in pixel coordinates,
        passed to random_add_patches2 together with the pasted patches
    :param save_base_dir: output root directory
    :param small_img_dir: list of crop image paths to paste
    :param class_dict: unused; kept for interface compatibility
    """
    if len(labels) == 0:
        return
    im_proj, im_geotrans, im_width, im_height, image = gdalTools.read_img(image_dir)
    # read_img presumably returns bands first (C, H, W); OpenCV needs (H, W, C).
    image = image.transpose((1, 2, 0))
    rescale_labels = [[str(label), int(box[0]), int(box[1]), int(box[2]), int(box[3])]
                      for label, box in zip(labels, boxes)]

    for small_img_path in small_img_dir:
        roi = cv2.imread(small_img_path)
        if roi is None:
            # cv2.imread returns None for unreadable files; skip that crop.
            print(small_img_path)
            continue
        small_class = os.path.split(small_img_path)[-1].split('_')[0]

        new_bboxes = random_add_patches2(roi.shape, small_class, rescale_labels, image.shape,
                                         paste_number=1, iou_thresh=0)
        for count, new_bbox in enumerate(new_bboxes, start=1):
            _, bbox_left, bbox_top, bbox_right, bbox_bottom = new_bbox[:5]
            height, width = roi.shape[:2]
            center = (int(width / 2), int(height / 2))
            mask = 255 * np.ones(roi.shape, roi.dtype)  # clone the whole crop
            try:
                if count > 1:
                    roi = flip_bbox(roi)
                image[bbox_top:bbox_bottom, bbox_left:bbox_right] = cv2.seamlessClone(
                    roi, image[bbox_top:bbox_bottom, bbox_left:bbox_right],
                    mask, center, cv2.NORMAL_CLONE)
                # Register the pasted patch so later random_add_patches2 calls
                # see it among the existing boxes.
                rescale_labels.append(new_bbox)
            except ValueError:
                print("---")
                continue

    save_dir = join(save_base_dir, find_str(image_dir))
    check_dir(save_dir)
    outfileName = os.path.join(save_dir, basename(image_dir))
    # Back to bands-first for writing with the original georeferencing.
    gdalTools.write_img(outfileName, im_proj, im_geotrans, image.transpose((2, 0, 1)))


def get_key(dct, value):
    """Return every key of *dct* whose value equals *value*, in dict order."""
    matches = []
    for key, val in dct.items():
        if val == value:
            matches.append(key)
    return matches


class GEN_Annotations:
    """Incrementally build a Pascal-VOC style annotation XML tree.

    Typical use: construct it with the image file name, call set_size once,
    call add_pic_attr once per object, then call savefile.
    """

    def __init__(self, filename):
        """Create the <annotation> root with <folder>, <filename> and <source>.

        :param filename: text stored in the <filename> element
        """
        self.root = etree.Element("annotation")

        child1 = etree.SubElement(self.root, "folder")
        child1.text = "VOC2007"

        child2 = etree.SubElement(self.root, "filename")
        child2.text = filename

        # Fixed placeholder values for the <source> block.
        child3 = etree.SubElement(self.root, "source")

        child4 = etree.SubElement(child3, "annotation")
        child4.text = "PASCAL VOC2007"
        child5 = etree.SubElement(child3, "database")
        child5.text = "Unknown"

        child6 = etree.SubElement(child3, "image")
        child6.text = "flickr"
        child7 = etree.SubElement(child3, "flickrid")
        child7.text = "35435"

    def set_size(self, witdh, height, channel):
        """Append a <size> element with <width>, <height> and <depth>.

        The misspelled parameter name ``witdh`` is kept so that callers
        passing it by keyword keep working.
        """
        size = etree.SubElement(self.root, "size")
        widthn = etree.SubElement(size, "width")
        widthn.text = str(witdh)
        heightn = etree.SubElement(size, "height")
        heightn.text = str(height)
        channeln = etree.SubElement(size, "depth")
        channeln.text = str(channel)

    def savefile(self, filename):
        """Write the tree to *filename* (UTF-8, pretty-printed, no XML declaration)."""
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax, difficult=0):
        """Append an <object> element describing one bounding box.

        :param label: class name stored in <name>
        :param xmin: left edge, written with str()
        :param ymin: top edge, written with str()
        :param xmax: right edge, written with str()
        :param ymax: bottom edge, written with str()
        :param difficult: value of the <difficult> flag (default 0). This is
            a standard VOC object field, and VOC readers commonly expect it
            on every object.
        """
        obj = etree.SubElement(self.root, "object")
        namen = etree.SubElement(obj, "name")
        namen.text = label
        difficultn = etree.SubElement(obj, "difficult")
        difficultn.text = str(difficult)
        bndbox = etree.SubElement(obj, "bndbox")
        for tag, value in (("xmin", xmin), ("ymin", ymin), ("xmax", xmax), ("ymax", ymax)):
            coord = etree.SubElement(bndbox, tag)
            coord.text = str(value)

源码地址

https://github.com/SonwYang/GenerativeNegativeSamples

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

点PY

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值