💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖
本博客的精华专栏:
【自动化测试】 【测试经验】 【人工智能】 【Python】
TensorFlow 深度学习 | Dataset API 数据读取详解
在深度学习任务中,高效的数据输入管道至关重要。TensorFlow 提供的
tf.data.Dataset
API,可以帮助我们从多种数据源(如内存数组、CSV 文件、图片数据、TFRecord 等)高效读取、处理并批量提供数据,极大提升训练效率。本文将带你深入理解 Dataset API 的用法与核心思想。
📌 一、为什么要使用 Dataset API?
在深度学习模型训练中,常见的数据加载方式有:
- 一次性加载到内存:适合小规模数据,但当数据集较大时会导致内存不足。
- 逐批次手动加载:代码繁琐,不利于扩展与优化。
而 Dataset API 的优势在于:
- 高效性:支持流水线式的数据预处理(如 shuffle、batch、map)。
- 灵活性:可轻松适配不同数据源(numpy 数组、CSV 文件、TFRecord 等)。
- 可扩展性:能处理海量数据,支持分布式训练的数据输入。
📂 二、Dataset API 的核心概念
在正式使用之前,我们需要理解 Dataset API 的三个关键角色:
1. Dataset
数据集的抽象表示,可以来自 内存数据、文件、生成器等。
2. Iterator
迭代器,用于遍历 Dataset。
3. Transformation
对数据进行的操作,例如:
map(fn)
:对每个元素应用函数batch(size)
:按批次组合shuffle(buffer_size)
:打乱数据顺序repeat(count)
:重复数据集
🔢 三、Dataset API 基本用法
下面以 Numpy 数据 为例,展示 Dataset API 的基本流程。
1. 从 Numpy 数组创建 Dataset
import tensorflow as tf
import numpy as np
# 假设我们有输入数据 X 和标签 y
X = np.array([[1, 2], [3, 4], [5, 6]])
y = np.array([0, 1, 0])
# 将 Numpy 数据转换为 Dataset
dataset = tf.data.Dataset.from_tensor_slices((X, y))
for features, label in dataset:
print(features.numpy(), label.numpy())
2. 批处理与打乱
# 批量处理 + 打乱数据
dataset = dataset.shuffle(buffer_size=3).batch(2)
for batch_x, batch_y in dataset