人工智能学习64-Yolo样本处理—快手视频
人工智能学习65-Yolo样本处理—快手视频
Yolo算法样本格式
Yolo算法样本数据采用VOC数据规范,初学者可以直接下载使用,也可以按照VOC数据规范补充自己的样本数据,训练自己的Yolo网络模型,目录结构如下图:
目录VOC2007/Annotations中存放XML描述文件,XML文件名称与图片文件同名,XML文件内容如下图:
目录VOC2007/JPEGImages中存放图片文件,图片文件扩展名为jpg,图片文件大小不限制,网络训练时统一转化为416*416。
目录VOC2007/ImageSets/Main中存放训练文件、验证文件、训练测试文件的索引信息,如下图:
物体种类配置文件
Yolo算法可以训练两类模型,一是20类物种模型,二是80类物种模型。20类物种模型配置文件在/model_data/voc_classes.txt;80类物种模型配置文件在/model_data/coco_classes.txt。如果新增自己样本,需要修改这两个文件内容。
补充样本数据步骤
将图片文件复制到VOC2007/JPEGImages目录下,文件扩展名必须为jpg。在VOC2007/Annotations目录新建与图片文件同名的XML文件,按照如下数据修改XML文件内容。
<annotation>
<folder>VOC2007</folder>
<filename>000001.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>341012865</flickrid>
</source>
<owner>
<flickrid>Fried Camels</flickrid>
<name>Jinky the Fruit Bat</name>
</owner>
<size>
<width>353</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>dog</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>48</xmin>
<ymin>240</ymin>
<xmax>195</xmax>
<ymax>371</ymax>
</bndbox>
</object>
<object>
<name>person</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>8</xmin>
<ymin>12</ymin>
<xmax>352</xmax>
<ymax>498</ymax>
</bndbox>
</object>
</annotation>
使用Window画板或其他工具,在新增图片中使用Box标明物体坐标与种类,如下图:
Box框定物种分类需要从/model_data/coco_classes.txt或/model_data/voc_classes.txt查找,如果是新增物种,需要修改这两个配置文件。
使用脚本生成训练和测试配置文件
运行脚本voc_annotation.py,在根目录下生成两个配置文件2007_train.txt和2007_val.txt。
voc_annotation.py文件内容
import os
import random
import xml.etree.ElementTree as ET
import numpy as np
from utils import get_classes
#--------------------------------------------------------------------------------------------------------------------------------#
# annotation_mode用于指定该文件运行时计算的内容
# annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
# annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
# annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0
#-------------------------------------------------------------------#
# 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息
# 与训练和预测所用的classes_path一致即可
# 如果生成的2007_train.txt里面没有目标信息
# 那么就是因为classes没有设定正确
# 仅在annotation_mode为0和2的时候有效
#-------------------------------------------------------------------#
classes_path = '../model_data/coco_classes.txt' #原来是:voc_classes.txt, coco_classes.txt
#--------------------------------------------------------------------------------------------------------------------------------#
# trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1
# train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1
# 仅在annotation_mode为0和1的时候有效
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9
#-------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path = '../VOCdevkit'
VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
classes, _ = get_classes(classes_path)
#-------------------------------------------------------#
# 统计目标数量
#-------------------------------------------------------#
photo_nums = np.zeros(len(VOCdevkit_sets))
nums = np.zeros(len(classes))
def convert_annotation(year, image_id, list_file):
in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
tree=ET.parse(in_file)
root = tree.getroot()
for obj in root.iter('object'):
difficult = 0
if obj.find('difficult')!=None:
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult)==1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
nums[classes.index(cls)] = nums[classes.index(cls)] + 1
if __name__ == "__main__":
random.seed(0)
if " " in os.path.abspath(VOCdevkit_path):
raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")
if annotation_mode == 0 or annotation_mode == 1:
print("Generate txt in ImageSets.")
xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
if xml.endswith(".xml"):
total_xml.append(xml)
num = len(total_xml)
list = range(num)
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval= random.sample(list,tv)
train = random.sample(trainval,tr)
print("train and val size",tv)
print("train size",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
for i in list:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
if annotation_mode == 0 or annotation_mode == 2:
print("Generate 2007_train.txt and 2007_val.txt for train.")
type_index = 0
for year, image_set in VOCdevkit_sets:
image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
list_file = open('../%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
for image_id in image_ids:
list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
convert_annotation(year, image_id, list_file)
list_file.write('\n')
photo_nums[type_index] = len(image_ids)
type_index += 1
list_file.close()
print("Generate 2007_train.txt and 2007_val.txt for train done.")
def printTable(List1, List2):
for i in range(len(List1[0])):
print("|", end=' ')
for j in range(len(List1)):
print(List1[j][i].rjust(int(List2[j])), end=' ')
print("|", end=' ')
print()
str_nums = [str(int(x)) for x in nums]
tableData = [
classes, str_nums
]
colWidths = [0]*len(tableData)
len1 = 0
for i in range(len(tableData)):
for j in range(len(tableData[i])):
if len(tableData[i][j]) > colWidths[i]:
colWidths[i] = len(tableData[i][j])
printTable(tableData, colWidths)
if photo_nums[0] <= 500:
print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")
if np.sum(nums) == 0:
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("(重要的事情说三遍)。")
2007_train.txt文件内容
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000001.jpg 48,240,195,371,16 8,12,352,498,0
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000002.jpg 139,200,207,301,6
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000003.jpg 123,155,215,195,57 239,156,307,205,56
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,56 165,264,253,372,56 241,194,295,299,56
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000006.jpg 187,135,282,242,58 154,209,369,375,60 255,207,366,375,56 138,211,249,375,56
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000007.jpg 141,50,500,330,2
E:\HK_SYS\KerasGit\yolo3-keras\VOCdevkit/VOC2007/JPEGImages/000008.jpg 192,16,364,249,56
……
Yolo算法的标签数据
根据样本标签数据生成Yolo网络训练时使用image_data,y_true数据。
image_data是图片二进制编码数据,数据格式(n,w,h,c)=(16, 416, 416, 3)。其中16是图片数量,每次取16幅图片用于训练网络模型,416*416分别是图片宽度和高度,3代表图片的RGBs三个颜色通道。
y_true是标签数据,数据格式(n,w,h,boxes,classes)=(16,13,13,3,85)、(16,26,26,3,85)、(16,52,52,3,85)。其中16是图片数量,训练样本每幅图片都有标签值,1313、2626、52*52代表三个不同大小的特征图,boxes=3代表每个尺寸的特征图设定3个先验框box,先验框配置文件在/model_data/yolo_anchors.txt中设置。Classes=85代表80类物体种类和坐标信息,前五个元素分别是物体x、y、w、h、confidence,后80个元素是80类物体其中之一的one-hot编码,具体数据格式如下: