前言
在目标检测推理过程中，经常出现误检的情况。将误检的目标作为负样本引入样本集中，可以大幅降低误检率。
思路
- 找出误检的目标，并将其裁剪成切片；
- 将这些负样本切片随机粘贴到原样本集的影像中。
原样本切片
引入负样本后的切片
代码
small_objects_paste.py
import aug as am
import glob
import random
import os
from shp2imagexy import shp2imagexy
def mkdir(path):
    """Create directory *path*, including missing parents, if it does not exist yet.

    A single os.makedirs(exist_ok=True) call avoids the race between an
    exists() check and mkdir(), and also works for nested paths.

    :param path: directory to create
    :raises FileExistsError: if *path* exists but is not a directory
    """
    os.makedirs(path, exist_ok=True)
def generate_img(image_dir,
                 crop_dir,
                 out_dir,
                 small_num=2):
    """Paste randomly chosen negative-sample crops into every image of a dataset.

    Images are looked up as ``<image_dir>/<sample>/*.tif``. The first ``*.shp``
    in the same sample folder supplies the existing labels. Crops are looked up
    as ``<crop_dir>/*.tif``. Each image is handed to ``am.copysmallobjects2``
    together with its labels and the chosen crops, and the result goes to
    ``out_dir``.

    :param image_dir: root folder containing one sub-folder per sample
    :param crop_dir: folder containing the negative-sample crops
    :param out_dir: output folder (created if missing)
    :param small_num: crops pasted per image; capped at the number of crops found
    """
    mkdir(out_dir)
    # os.path.join (rather than a hard-coded './' prefix) keeps absolute paths valid
    imglist = glob.glob(os.path.join(image_dir, '*', '*.tif'))
    small_imglist = glob.glob(os.path.join(crop_dir, '*.tif'))
    print(f'the length of img is {len(imglist)}')
    print(f'the length of small img is {len(small_imglist)}')
    random.shuffle(imglist)
    random.shuffle(small_imglist)
    # random.sample raises ValueError when asked for more items than exist
    sample_size = min(small_num, len(small_imglist))
    for imgPath in imglist:
        subRoot = os.path.dirname(imgPath)
        shpPaths = glob.glob(os.path.join(subRoot, '*.shp'))
        if not shpPaths:
            print(f'no shapefile next to {imgPath}, skipped')
            continue
        small_img = random.sample(small_imglist, sample_size)
        labels, boxes = shp2imagexy(imgPath, shpPaths[0])
        am.copysmallobjects2(imgPath, labels, boxes, out_dir, small_img, class_dict={'1': 1})
if __name__ == '__main__':
    from shutil import copyfile

    image_dir = 'playground'
    crop_dir = 'playground_neg/images'
    out_dir = 'playground_v2'
    generate_img(image_dir, crop_dir, out_dir)

    # Build a 'V2' copy of every sample folder whose path contains 'V1'.
    # Each file is renamed to '<base>_v2.<suffix>'. The .tif is taken from
    # out_dir (the pasted image); every other file is copied unchanged.
    for entry in os.listdir(image_dir):
        sample_dir = os.path.join(image_dir, entry)
        if not os.path.isdir(sample_dir):
            continue  # only sample folders are processed
        new_dir = sample_dir.replace('V1', 'V2')
        if new_dir == sample_dir:
            # No 'V1' in the path (e.g. a V2 folder from an earlier run): the
            # copy would target the folder itself and clash with its own files.
            print(f'skip {sample_dir}: no V1 in its path')
            continue
        mkdir(new_dir)
        for name in os.listdir(sample_dir):
            suffix = name.split('.')[-1]
            baseName = name.split('.')[0].split('_')[0]
            target = os.path.join(new_dir, f'{baseName}_v2.{suffix}')
            source = os.path.join(sample_dir, name)
            if suffix == 'tif':
                augmented = f'{out_dir}/{baseName}.tif'
                if os.path.exists(augmented):
                    source = augmented
                else:
                    # copysmallobjects2 returns without writing for samples with no labels
                    print(f'no augmented image {augmented}, copying {source}')
            copyfile(source, target)
shp2imagexy.py
# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from osgeo import osr
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import math
def getSRSPair(dataset):
    '''
    Return the projected CRS of ``dataset`` and the geographic CRS it is based on.

    GDAL >= 3 applies the authority axis order by default, so for e.g. EPSG:4326
    a CoordinateTransformation yields (lat, lon) instead of (lon, lat). Both
    references are switched to the traditional GIS order (x = lon, y = lat).
    On GDAL 2 the constant does not exist and the old behaviour already matches.

    :param dataset: GDAL dataset (anything exposing GetProjection())
    :return: (prosrs, geosrs) as osr.SpatialReference objects
    '''
    prosrs = osr.SpatialReference()
    prosrs.ImportFromWkt(dataset.GetProjection())
    geosrs = prosrs.CloneGeogCS()
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        prosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        geosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    return prosrs, geosrs
def xy_to_coor(x, y):
    '''
    Convert plane coordinates back to (lon, lat) in degrees.

    The formula has the shape of the inverse Miller cylindrical projection.
    The plane is W x H with W = 2*pi*6381372 and H = W/2. y grows downwards,
    and (W/2, H/2) maps to (0, 0).
    NOTE(review): 6381372 differs from the usual Earth radii (6378137 / 6371000);
    kept as-is for compatibility with existing data — confirm.

    :param x: plane x, 0..W maps to lon -180..180
    :param y: plane y, smaller y means more northern latitude
    :return: (lon, lat) rounded to 7 decimal places
    '''
    L = 6381372 * math.pi * 2
    W = L
    H = L / 2
    mill = 2.3
    t = ((H / 2 - y) * 2 * mill) / (1.25 * H)
    lat = ((math.atan(math.exp(t)) - 0.25 * math.pi) * 180) / (0.4 * math.pi)
    lon = (x - W / 2) * 360 / W
    # TODO: confirm how many decimal places the final lon/lat must keep
    return round(lon, 7), round(lat, 7)
def geo2lonlat(dataset, x, y):
    '''
    Convert a projected coordinate of ``dataset``'s CRS into geographic coordinates.

    The source CRS is the dataset's projection and the target is the geographic
    CRS it is based on (both from getSRSPair).

    :param dataset: GDAL dataset whose projection defines the source CRS
    :param x: projected x
    :param y: projected y
    :return: first two components of the transformed point, intended as (lon, lat)
    '''
    source_srs, target_srs = getSRSPair(dataset)
    transform = osr.CoordinateTransformation(source_srs, target_srs)
    point = transform.TransformPoint(x, y)
    return point[:2]
def lonlat2geo(dataset, lon, lat):
    '''
    Convert geographic coordinates into ``dataset``'s projected CRS.

    This is the reverse direction of geo2lonlat: the transformation goes from
    the geographic CRS to the projected CRS returned by getSRSPair.

    :param dataset: GDAL dataset whose projection defines the target CRS
    :param lon: longitude
    :param lat: latitude
    :return: first two components of the transformed point, i.e. projected (x, y)
    '''
    projected_srs, geographic_srs = getSRSPair(dataset)
    transform = osr.CoordinateTransformation(geographic_srs, projected_srs)
    point = transform.TransformPoint(lon, lat)
    return point[:2]
def imagexy2geo(dataset, col, row):
    '''
    Map a pixel position to projected/geographic coordinates via GDAL's geotransform.

    Six-parameter affine model:
        px = GT[0] + col * GT[1] + row * GT[2]
        py = GT[3] + col * GT[4] + row * GT[5]

    :param dataset: GDAL dataset (anything exposing GetGeoTransform())
    :param col: pixel column (x index)
    :param row: pixel row (y index)
    :return: (px, py) in the dataset's coordinate system
    '''
    trans = dataset.GetGeoTransform()
    px = trans[0] + col * trans[1] + row * trans[2]
    py = trans[3] + col * trans[4] + row * trans[5]
    return px, py
def geo2imagexy01(dataset, x, y):
    '''
    Invert GDAL's affine geotransform by solving the 2x2 linear system with numpy.

        [GT1 GT2] [col]   [x - GT0]
        [GT4 GT5] [row] = [y - GT3]

    :param dataset: GDAL dataset (anything exposing GetGeoTransform())
    :param x: projected or geographic x
    :param y: projected or geographic y
    :return: numpy array [col, row] (fractional pixel column first, then row)
    :raises numpy.linalg.LinAlgError: if the geotransform matrix is singular
    '''
    gt = dataset.GetGeoTransform()
    coefficients = np.array([[gt[1], gt[2]],
                             [gt[4], gt[5]]])
    offsets = np.array([x - gt[0], y - gt[3]])
    return np.linalg.solve(coefficients, offsets)
def geo2imagexy(dataset, x, y):
    '''
    Invert GDAL's affine geotransform in closed form (Cramer's rule).

    :param dataset: GDAL dataset (anything exposing GetGeoTransform())
    :param x: projected or geographic x
    :param y: projected or geographic y
    :return: [col, row] as a list, i.e. fractional pixel column (Xpixel) then row (Yline)
    '''
    gt = dataset.GetGeoTransform()
    dx = x - gt[0]
    dy = y - gt[3]
    det = gt[1] * gt[5] - gt[2] * gt[4]
    col = (gt[5] * dx - gt[2] * dy) / det
    row = (gt[1] * dy - gt[4] * dx) / det
    return [col, row]
def shp2imagexy(imgPath, shpPath):
    '''
    Convert the envelope of every feature in a vector file to a pixel box of an image.

    :param imgPath: raster whose geotransform defines the pixel grid
    :param shpPath: vector file with a 'cls' attribute per feature
    :return: (class_list, boxes_list). class_list holds 1 for every converted
        feature. boxes_list holds [xmin, ymin, xmax, ymax, cls] in pixel units.
        Features without geometry are skipped.
    :raises IOError: if the raster or the vector file cannot be opened
    '''
    dataset = gdal.Open(imgPath)
    if dataset is None:
        raise IOError(f'Could not open image: {imgPath}')
    # read-only: nothing is written back to the vector file
    ds = ogr.Open(shpPath, 0)
    if ds is None:
        raise IOError(f'Could not open shapefile: {shpPath}')
    in_lyr = ds.GetLayer()
    class_list = []
    boxes_list = []
    feature = in_lyr.GetNextFeature()
    while feature is not None:
        geom = feature.GetGeometryRef()
        if geom is not None:
            cls_id = feature.GetField('cls')
            # GetEnvelope() -> (minX, maxX, minY, maxY). Assuming a north-up
            # image, (minX, maxY) is the top-left and (maxX, minY) the bottom-right.
            min_x, max_x, min_y, max_y = geom.GetEnvelope()
            coordsMin = geo2imagexy(dataset, min_x, max_y)
            coordsMax = geo2imagexy(dataset, max_x, min_y)
            boxes_list.append([coordsMin[0], coordsMin[1], coordsMax[0], coordsMax[1], cls_id])
            class_list.append(1)
        feature = in_lyr.GetNextFeature()
    return class_list, boxes_list
if __name__ == '__main__':
    # Visual check: draw the boxes converted from the shapefile onto the image.
    img_filename = 'E:/2_1/tc/0000000001_V1/0000000001.tif'
    dst_filename = 'E:/2_1/tc/0000000001_V1/0000000001_V1_POLY.shp'
    # shp2imagexy returns (class_list, boxes_list); only the boxes are drawn
    _, boxes = shp2imagexy(img_filename, dst_filename)
    img = cv.imread(img_filename, cv.IMREAD_LOAD_GDAL)
    if img is None:
        raise SystemExit(f'could not read {img_filename}')
    for bbox in boxes:
        print(bbox)
        cv.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 100, 255), 5)
    # OpenCV stores colour images as BGR while matplotlib expects RGB
    if img.ndim == 3 and img.shape[2] == 3:
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.show()
aug.py
from lxml import etree
from os.path import basename, split, join, dirname
from util import *
import xml.etree.ElementTree as ET
import gdalTools
def find_str(filename):
    """Return the directory part of *filename* starting at its 'train' or 'val' substring.

    Used to mirror a train/val folder layout under an output root, e.g.
    'data/train/img/a.jpg' -> 'train/img'. 'train' is searched first.
    Returns '' when neither substring occurs, so output lands directly in the root.

    :param filename: path of an image or label file
    :return: relative directory to recreate under an output root
    """
    for key in ('train', 'val'):
        pos = filename.find(key)
        if pos != -1:
            return dirname(filename[pos:])
    return ''
def convert_all_boxes(shape, anno_infos, yolo_label_txt_dir):
    """Write boxes to a YOLO-style label file, one line per box.

    :param shape: image shape; only (height, width) = shape[:2] is used
    :param anno_infos: iterable of (target_id, x1, y1, x2, y2) in pixels
    :param yolo_label_txt_dir: path of the .txt file to (over)write
    """
    height, width = shape[0], shape[1]
    # the with-block guarantees the file is flushed and closed
    with open(yolo_label_txt_dir, 'w') as label_file:
        for target_id, x1, y1, x2, y2 in anno_infos:
            # util.convert is called with the box in (x1, x2, y1, y2) order
            bb = convert((width, height), (float(x1), float(x2), float(y1), float(y2)))
            label_file.write(str(target_id) + " " + " ".join(str(a) for a in bb) + '\n')
def save_crop_image(save_crop_base_dir, image_dir, idx, roi):
    """Save a cropped object as '<stem>_crop_<idx>.jpg' under the mirrored train/val folder.

    :param save_crop_base_dir: output root
    :param image_dir: path of the source image; its file stem names the crop
    :param idx: index of the crop within the source image
    :param roi: image array to write
    """
    from os.path import splitext

    crop_save_dir = join(save_crop_base_dir, find_str(image_dir))
    check_dir(crop_save_dir)
    # splitext removes the whole extension; slicing off 3 chars left a trailing '.'
    stem = splitext(basename(image_dir))[0]
    crop_img_save_dir = join(crop_save_dir, stem + '_crop_' + str(idx) + '.jpg')
    cv2.imwrite(crop_img_save_dir, roi)
def copysmallobjects(image_dir, label_dir, save_base_dir, save_crop_base_dir=None,
                     save_annoation_base_dir=None):
    """Copy-paste augmentation of small class-'1' objects in a .jpg with a label .txt.

    Each labelled object with class '1' that issmallobject() accepts
    (thresh=64*64) is cropped. It is then pasted at the positions proposed by
    random_add_patches (paste_number=2, iou_thresh=0.2). Pastes after the first
    use flip_bbox(roi). A new box is added to the labels only when its paste
    succeeded. Writes '<name>_augment.jpg' and '<name>_augment.txt' to
    save_base_dir/find_str(image_dir).

    :param image_dir: path of the source .jpg
    :param label_dir: path of its label .txt (read with read_label_txt)
    :param save_base_dir: output root
    :param save_crop_base_dir: unused
    :param save_annoation_base_dir: unused
    """
    image = cv2.imread(image_dir)
    if image is None:
        print(f'could not read {image_dir}')
        return
    labels = read_label_txt(label_dir)
    if len(labels) == 0:
        return
    rescale_labels = rescale_yolo_labels(labels, image.shape)  # convert to pixel coordinates
    all_boxes = []
    for rescale_label in rescale_labels:
        all_boxes.append(rescale_label)
        # object height and width
        label_height = rescale_label[4] - rescale_label[2]
        label_width = rescale_label[3] - rescale_label[1]
        if not (issmallobject((label_height, label_width), thresh=64 * 64) and rescale_label[0] == '1'):
            continue
        roi = image[rescale_label[2]:rescale_label[4], rescale_label[1]:rescale_label[3]]
        new_bboxes = random_add_patches(rescale_label, rescale_labels, image.shape, paste_number=2, iou_thresh=0.2)
        for count, new_bbox in enumerate(new_bboxes, start=1):
            bbox_left, bbox_top, bbox_right, bbox_bottom = new_bbox[1], new_bbox[2], new_bbox[3], new_bbox[4]
            try:
                if count > 1:
                    roi = flip_bbox(roi)
                image[bbox_top:bbox_bottom, bbox_left:bbox_right] = roi
            except ValueError:
                # shape mismatch: nothing was pasted, so the box must not be labelled
                continue
            all_boxes.append(new_bbox)
    dir_name = find_str(image_dir)
    save_dir = join(save_base_dir, dir_name)
    check_dir(save_dir)
    yolo_txt_dir = join(save_dir, basename(image_dir.replace('.jpg', '_augment.txt')))
    cv2.imwrite(join(save_dir, basename(image_dir).replace('.jpg', '_augment.jpg')), image)
    convert_all_boxes(image.shape, all_boxes, yolo_txt_dir)
def GaussianBlurImg(image):
    """Blur *image* with a random odd-sized Gaussian kernel half of the time.

    An integer k is drawn uniformly from 0..9. Odd k (1, 3, 5, 7, 9) applies a
    k x k Gaussian blur with sigma derived from k (sigmaX=sigmaY=0). Even k
    returns *image* unchanged.
    """
    kernel = random.randint(0, 9)
    if kernel % 2 == 0:
        return image
    return cv2.GaussianBlur(image, ksize=(kernel, kernel), sigmaX=0, sigmaY=0)
def suo_fang(image, area_max=1100, area_min=700):
# 改变图片大小
height, width, channels = image.shape
while (height*width) > area_max:
image = cv2.resize(image, (int(width * 0.95),int(height * 0.95)))
height, width, channels = image.shape
height, width = int(height*0.95), int(width*0.95)
while (height*width) < area_min:
image = cv2.resize(image, (int(width * 1.2), int(height * 1.2)))
height, width, channels = image.shape
height, width = int(height*1.2), int(width*1.2)
return image
def parse_xml(path, class_dict):
    """Read object classes and boxes from a Pascal-VOC style XML file.

    <difficult> is not read; its value was never returned, so files without
    that tag are accepted.

    :param path: XML file path
    :param class_dict: maps an object's <name> text to a class id
    :return: (classes, boxes). classes is a 1-D array of ids. boxes is an int32
        array of shape (N, 4) with rows [xmin, ymin, xmax, ymax]; coordinates
        are truncated toward zero. N may be 0.
    :raises KeyError: if an object's name is missing from class_dict
    """
    root = ET.parse(path).getroot()
    class_list = []
    boxes_list = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        boxes_list.append([float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
        class_list.append(class_dict[obj.find('name').text])
    # reshape keeps an (N, 4) shape even when the file has no objects
    boxes = np.array(boxes_list, dtype=np.float64).reshape(-1, 4).astype(np.int32)
    return np.array(class_list), boxes
def copysmallobjects2(image_dir, labels, boxes, save_base_dir, small_img_dir, class_dict):
    """Blend negative-sample crops into a georeferenced image and save it.

    For each crop, random_add_patches2 proposes positions (paste_number=1,
    iou_thresh=0). The crop is blended in with cv2.seamlessClone(NORMAL_CLONE).
    Accepted positions are appended to the running label list so later crops
    avoid them. Pasted crops are not written to any annotation. The image is
    written with its original projection and geotransform to
    save_base_dir/find_str(image_dir)/<basename>.

    :param image_dir: path of the source raster (read with gdalTools.read_img)
    :param labels: class per existing object; nothing is read or written when empty
    :param boxes: pixel boxes [xmin, ymin, xmax, ymax, ...] matching labels
    :param save_base_dir: output root
    :param small_img_dir: list of crop image paths; the crop's class is the
        file-name prefix before the first '_'
    :param class_dict: unused
    """
    if len(labels) == 0:
        return  # checked before reading the raster to avoid needless I/O
    im_proj, im_geotrans, _, _, image = gdalTools.read_img(image_dir)
    image = image.transpose((1, 2, 0))  # bands-first -> height x width x bands
    rescale_labels = [[str(label), int(box[0]), int(box[1]), int(box[2]), int(box[3])]
                      for label, box in zip(labels, boxes)]
    for small_img_path in small_img_dir:
        roi = cv2.imread(small_img_path)
        if roi is None:
            print(f'could not read {small_img_path}')
            continue
        small_class = os.path.split(small_img_path)[-1].split('_')[0]
        new_bboxes = random_add_patches2(roi.shape, small_class, rescale_labels, image.shape,
                                         paste_number=1, iou_thresh=0)
        for count, new_bbox in enumerate(new_bboxes, start=1):
            bbox_left, bbox_top, bbox_right, bbox_bottom = new_bbox[1], new_bbox[2], new_bbox[3], new_bbox[4]
            height, width = roi.shape[:2]
            center = (int(width / 2), int(height / 2))
            mask = 255 * np.ones(roi.shape, roi.dtype)  # blend the whole crop
            try:
                if count > 1:
                    roi = flip_bbox(roi)
                target = image[bbox_top:bbox_bottom, bbox_left:bbox_right]
                image[bbox_top:bbox_bottom, bbox_left:bbox_right] = cv2.seamlessClone(
                    roi, target, mask, center, cv2.NORMAL_CLONE)
            except (ValueError, cv2.error) as err:
                # e.g. band-count mismatch or a target region clipped at the border
                print(f'paste of {small_img_path} skipped: {err}')
                continue
            rescale_labels.append(new_bbox)
    save_dir = join(save_base_dir, find_str(image_dir))
    check_dir(save_dir)
    outfileName = os.path.join(save_dir, basename(image_dir))
    gdalTools.write_img(outfileName, im_proj, im_geotrans, image.transpose((2, 0, 1)))
def get_key(dct, value):
    """Return every key of *dct* whose value equals *value*, in iteration order."""
    matches = []
    for key, val in dct.items():
        if val == value:
            matches.append(key)
    return matches
class GEN_Annotations:
    """Build a Pascal-VOC style annotation XML tree with lxml.

    Typical use: GEN_Annotations(filename), then set_size(...), then
    add_pic_attr(...) once per object, then savefile(path).
    """

    def __init__(self, filename):
        """Create the <annotation> root with <folder>, <filename> and a fixed <source> block.

        :param filename: text written into <filename>
        """
        self.root = etree.Element("annotation")
        etree.SubElement(self.root, "folder").text = "VOC2007"
        etree.SubElement(self.root, "filename").text = filename
        source = etree.SubElement(self.root, "source")
        etree.SubElement(source, "annotation").text = "PASCAL VOC2007"
        etree.SubElement(source, "database").text = "Unknown"
        etree.SubElement(source, "image").text = "flickr"
        etree.SubElement(source, "flickrid").text = "35435"

    def set_size(self, witdh, height, channel):
        """Append <size> with <width>, <height> and <depth>.

        :param witdh: image width (misspelt name kept for keyword callers)
        :param height: image height
        :param channel: number of channels, written as <depth>
        """
        size = etree.SubElement(self.root, "size")
        etree.SubElement(size, "width").text = str(witdh)
        etree.SubElement(size, "height").text = str(height)
        etree.SubElement(size, "depth").text = str(channel)

    def savefile(self, filename):
        """Write the tree to *filename*: UTF-8, pretty-printed, no XML declaration."""
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax, difficult=0):
        """Append one <object> with <name>, <difficult> and <bndbox>.

        :param label: class name; converted with str() because lxml only accepts text
        :param xmin: left edge
        :param ymin: top edge
        :param xmax: right edge
        :param ymax: bottom edge
        :param difficult: VOC difficult flag, 0 by default
        """
        obj = etree.SubElement(self.root, "object")
        etree.SubElement(obj, "name").text = str(label)
        # <difficult> is part of the VOC object schema; VOC readers that require it can then parse the file
        etree.SubElement(obj, "difficult").text = str(int(difficult))
        bndbox = etree.SubElement(obj, "bndbox")
        etree.SubElement(bndbox, "xmin").text = str(xmin)
        etree.SubElement(bndbox, "ymin").text = str(ymin)
        etree.SubElement(bndbox, "xmax").text = str(xmax)
        etree.SubElement(bndbox, "ymax").text = str(ymax)
源码地址
https://github.com/SonwYang/GenerativeNegativeSamples