TFRecord_CodingPark编程公园

TFRecord是TensorFlow提供的数据存储格式,用于优化大规模数据的存储和读取。它采用ProtocolBuffer二进制编码,减少磁盘占用并提升读取速度。生成TFRecord文件涉及创建TFRecordWriter,使用Example模块构造数据。Example支持BytesList, FloatList, Int64List类型数据。读取时,可通过tf.train.string_input_producer和tf.parse_single_example或slim库进行解析。对于不同尺寸的图片,可以resize成统一大小以便批量处理。TFRecord在处理大量训练数据时能显著提高效率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TFRecord生成

为什么使用TFRecord?

正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率

生成TFRecord简单实现方式

我们可以分成两个部分来介绍如何生成TFRecord,分别是

  • TFRecord生成器
  • Example(样本)模块。

TFRecord生成器

writer = tf.python_io.TFRecordWriter(record_path)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()

这里面writer就是我们TFrecord生成器。接着我们就可以通过writer.write(tf_example.SerializeToString())来生成我们所要的tfrecord文件了。这里需要注意的是我们TFRecord生成器在写完文件后需要关闭writer.close()。这里tf_example.SerializeToString()是将Example中的map压缩为二进制文件,更好的节省空间。那么tf_example是如何生成的呢?

那就是下面所要介绍的样本Example模块了。

Example模块

首先们来看一下Example协议块是什么样子的。

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

我们可以看出上面的tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。那我们如何写一个tf_example呢?下面有一个简单的例子。

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))

下面我们来好好从外部往内部分解来解释一下上面的内容。
(1)tf.train.Example(features = None) 这里的features是tf.train.Features类型的特征实例。
(2)tf.train.Features(feature = None) 这里的feature是以字典的形式存在,*key:要保存数据的名字 value:要保存的数据,但是格式必须符合tf.train.Feature实例要求。


TFRecord读取

上面我们介绍了如何生成TFRecord,现在我们尝试如何通过使用队列读取读取我们的TFRecord。
读取TFRecord可以通过tensorflow两个重要的函数实现

  • tf.train.string_input_producer
  • tf.TFRecordReadertf.parse_single_example解析器

如下图
在这里插入图片描述

读取TFRecord的简单实现方式

解析TFRecord有两种解析方式一种是利用tf.parse_single_example, 另一种是通过tf.contrib.slim(推荐使用)。

  1. 第一种方式(tf.parse_single_example)解析步骤如下:
    (1).第一步,我们将train.record文件读入到队列中,如下所示:
filename_queue = tf.train.string_input_producer([tfrecords_filename])

(2) 第二步,我们需要通过TFRecord将生成的队列读入

reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue) #返回文件名和文件

(3)第三步, 通过解析器tf.parse_single_example将我们的example解析出来。

  1. 第二种方式(tf.contrib.slim)解析步骤如下:
    (1)第一步, 我们要设置decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers), 其中key_to_features这个字典需要和TFrecord文件中定义的字典项匹配,items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。
    (2) 第二步, 我们要设定dataset = slim.dataset.Dataset(params), 其中params包括:
    a. data_source: 为tfrecord文件地址
    b. reader: 一般设置为tf.TFRecordReader阅读器
    c. decoder: 为第一步设置的decoder
    d. num_samples: 样本数量
    e. items_to_description: 对样本及标签的描述
    f. num_classes: 分类的数量
    (3) 第三步, 我们设置provider = slim.dataset_data_provider.DatasetDataProvider(params), 其中params包括 :
    a. dataset: 第二步骤我们生成的数据集
    b. num_reader: 并行阅读器数量
    c. shuffle: 是否打乱
    d. num_epochs:每个数据源被读取的次数,如果设为None数据将会被无限循环的读取
    e. common_queue_capacity:读取数据队列的容量,默认为256
    f. scope:范围
    g. common_queue_min:读取数据队列的最小容量。
    (4) 第四步, 我们可以通过provider.get得到我们需要的数据了。

对不同图片大小的TFRecord读取并resize成相同大小

reshape_same_size函数 来对图片进行resize,这样我们可以对我们的图片进行batch操作了,因为有的神经网络训练需要一个batch一个batch操作,不同大小的图片在组成一个batch的时候会报错,因此我们我通过后期处理可以更好的对图片进行batch操作。
或者直接通过resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))即可。

更多请参考

https://www.jianshu.com/p/b480e5fcb638

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TEAM-AG

编程公园:输出是最好的学习方式

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值