TFRecord生成
为什么使用TFRecord?
正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
生成TFRecord简单实现方式
我们可以分成两个部分来介绍如何生成TFRecord,分别是
- TFRecord生成器
- Example(样本)模块。
TFRecord生成器
writer = tf.python_io.TFRecordWriter(record_path)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
这里面writer就是我们TFrecord生成器。接着我们就可以通过writer.write(tf_example.SerializeToString())来生成我们所要的tfrecord文件了。这里需要注意的是我们TFRecord生成器在写完文件后需要关闭writer.close()。这里tf_example.SerializeToString()是将Example中的map压缩为二进制文件,更好的节省空间。那么tf_example是如何生成的呢?
那就是下面所要介绍的样本Example模块了。
Example模块
首先们来看一下Example协议块是什么样子的。
message Example {
Features features = 1;
};
message Features {
map<string, Feature> feature = 1;
};
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
我们可以看出上面的tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。那我们如何写一个tf_example呢?下面有一个简单的例子。
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
下面我们来好好从外部往内部分解来解释一下上面的内容。
(1)tf.train.Example(features = None) 这里的features是tf.train.Features类型的特征实例。
(2)tf.train.Features(feature = None) 这里的feature是以字典的形式存在,*key:要保存数据的名字 value:要保存的数据,但是格式必须符合tf.train.Feature实例要求。
TFRecord读取
上面我们介绍了如何生成TFRecord,现在我们尝试如何通过使用队列读取读取我们的TFRecord。
读取TFRecord可以通过tensorflow两个重要的函数实现
- tf.train.string_input_producer
- tf.TFRecordReader 的 tf.parse_single_example解析器
如下图
读取TFRecord的简单实现方式
解析TFRecord有两种解析方式一种是利用tf.parse_single_example, 另一种是通过tf.contrib.slim(推荐使用)。
- 第一种方式(tf.parse_single_example)解析步骤如下:
(1).第一步,我们将train.record文件读入到队列中,如下所示:
filename_queue = tf.train.string_input_producer([tfrecords_filename])
(2) 第二步,我们需要通过TFRecord将生成的队列读入
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
(3)第三步, 通过解析器tf.parse_single_example将我们的example解析出来。
- 第二种方式(tf.contrib.slim)解析步骤如下:
(1)第一步, 我们要设置decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers), 其中key_to_features这个字典需要和TFrecord文件中定义的字典项匹配,items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。
(2) 第二步, 我们要设定dataset = slim.dataset.Dataset(params), 其中params包括:
a. data_source: 为tfrecord文件地址
b. reader: 一般设置为tf.TFRecordReader阅读器
c. decoder: 为第一步设置的decoder
d. num_samples: 样本数量
e. items_to_description: 对样本及标签的描述
f. num_classes: 分类的数量
(3) 第三步, 我们设置provider = slim.dataset_data_provider.DatasetDataProvider(params), 其中params包括 :
a. dataset: 第二步骤我们生成的数据集
b. num_reader: 并行阅读器数量
c. shuffle: 是否打乱
d. num_epochs:每个数据源被读取的次数,如果设为None数据将会被无限循环的读取
e. common_queue_capacity:读取数据队列的容量,默认为256
f. scope:范围
g. common_queue_min:读取数据队列的最小容量。
(4) 第四步, 我们可以通过provider.get得到我们需要的数据了。
对不同图片大小的TFRecord读取并resize成相同大小
reshape_same_size函数 来对图片进行resize,这样我们可以对我们的图片进行batch操作了,因为有的神经网络训练需要一个batch一个batch操作,不同大小的图片在组成一个batch的时候会报错,因此我们我通过后期处理可以更好的对图片进行batch操作。
或者直接通过resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))即可。
更多请参考
https://www.jianshu.com/p/b480e5fcb638