TensorFlow 深度学习 | Dataset API 数据读取详解

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖

在这里插入图片描述

本博客的精华专栏:
【自动化测试】 【测试经验】 【人工智能】 【Python】


在这里插入图片描述

TensorFlow 深度学习 | Dataset API 数据读取详解

在深度学习任务中,高效的数据输入管道至关重要。TensorFlow 提供的 tf.data.Dataset API,可以帮助我们从多种数据源(如内存数组、CSV 文件、图片数据、TFRecord 等)高效读取、处理并批量提供数据,极大提升训练效率。本文将带你深入理解 Dataset API 的用法与核心思想。


📌 一、为什么要使用 Dataset API?

在深度学习模型训练中,常见的数据加载方式有:

  • 一次性加载到内存:适合小规模数据,但当数据集较大时会导致内存不足。
  • 逐批次手动加载:代码繁琐,不利于扩展与优化。

Dataset API 的优势在于:

  1. 高效性:支持流水线式的数据预处理(如 shuffle、batch、map)。
  2. 灵活性:可轻松适配不同数据源(numpy 数组、CSV 文件、TFRecord 等)。
  3. 可扩展性:能处理海量数据,支持分布式训练的数据输入。

📂 二、Dataset API 的核心概念

在正式使用之前,我们需要理解 Dataset API 的三个关键角色:

1. Dataset

数据集的抽象表示,可以来自 内存数据、文件、生成器等。

2. Iterator

迭代器,用于遍历 Dataset。

3. Transformation

对数据进行的操作,例如:

  • map(fn):对每个元素应用函数
  • batch(size):按批次组合
  • shuffle(buffer_size):打乱数据顺序
  • repeat(count):重复数据集

🔢 三、Dataset API 基本用法

下面以 Numpy 数据 为例,展示 Dataset API 的基本流程。

1. 从 Numpy 数组创建 Dataset

import tensorflow as tf
import numpy as np

# 假设我们有输入数据 X 和标签 y
X = np.array([[1, 2], [3, 4], [5, 6]])
y = np.array([0, 1, 0])

# 将 Numpy 数据转换为 Dataset
dataset = tf.data.Dataset.from_tensor_slices((X, y))

for features, label in dataset:
    print(features.numpy(), label.numpy())

2. 批处理与打乱

# 批量处理 + 打乱数据
dataset = dataset.shuffle(buffer_size=3).batch(2)

for batch_x, batch_y in dataset
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Thomas Kant

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值