【Python深度学习系列】一文搞明白PyTorch数据处理与加载(案例+源码)

这是我的第409篇原创文章。

一、引言

在深度学习任务中,数据加载和处理是至关重要的一环。

PyTorch 提供了强大的数据加载和处理工具,主要包括:

  • torch.utils.data.Dataset:数据集的抽象类,需要自定义并实现 __len__(数据集大小)和 __getitem__(按索引获取样本)。

  • torch.utils.data.TensorDataset:基于张量的数据集,适合处理数据-标签对,直接支持批处理和迭代。

  • torch.utils.data.DataLoader:封装 Dataset 的迭代器,提供批处理、数据打乱、多线程加载等功能,便于数据输入模型训练。

  • torchvision.datasets.ImageFolder:从文件夹加载图像数据,每个子文件夹代表一个类别,适用于图像分类任务。

二、实现过程

2.1 torch.utils.data.Dataset

Dataset 是 PyTorch 中用于数据集抽象的类。自定义数据集需要继承 torch.utils.data.Dataset 并重写以下两个方法:

  • __len__:返回数据集的大小。

  • __getitem__:按索引获取一个数据样本及其标签。

代码:

import torch
from torch.utils.data import Dataset

# 自定义数据集
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 数据初始化
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

    def __getitem__(self, idx):
        # 按索引返回数据和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 生成示例数据
data = torch.randn(100, 5)  # 100 个样本,每个样本有 5 个特征
labels = torch.randint(0, 2, (100,))  # 100 个标签,取值为 0 或 1

# 实例化数据集
dataset = MyDataset(data, labels)

# 测试数据集
print("数据集大小:", len(dataset))
print("第 0 个样本:", dataset[0])

输出结果:

图片

2.2 torch.utils.data.DataLoader

DataLoader 是 PyTorch 提供的数据加载器,用于批量加载数据集。提供了以下功能:

  • 批量加载:通过设置 batch_size

  • 数据打乱:通过设置 shuffle=True

  • 多线程加速:通过设置 num_workers

  • 迭代访问:方便地按批次访问数据。

代码:

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# 自定义数据集
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 数据初始化
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

    def __getitem__(self, idx):
        # 按索引返回数据和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 生成示例数据
data = torch.randn(100, 5)  # 100 个样本,每个样本有 5 个特征
labels = torch.randint(0, 2, (100,))  # 100 个标签,取值为 0 或 1

# 实例化数据集
dataset = MyDataset(data, labels)
# 实例化 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

# 遍历 DataLoader
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"批次 {batch_idx + 1}")
    print("数据:", batch_data)
    print("标签:", batch_labels)
    if batch_idx == 2:  # 仅显示前 3 个批次
        break

输出结果:

图片

每次循环中,DataLoader 会返回一个批次的数据,包括输入特征(inputs)和目标标签(labels)。

2.3 Dataset 与 DataLoader 的自定义应用

以下是一个将 CSV 文件 作为数据源,并通过自定义 Dataset 和 DataLoader 读取数据。

图片

代码:

import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# 自定义 CSV 数据集
class CSVDataset(Dataset):
    def __init__(self, file_path):
        # 读取 CSV 文件
        self.data = pd.read_csv(file_path)

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

    def __getitem__(self, idx):
        # 使用 .iloc 明确基于位置索引
        row = self.data.iloc[idx]
        # 将特征和标签分开
        features = torch.tensor(row.iloc[:-1].to_numpy(), dtype=torch.float32)  # 特征
        label = torch.tensor(row.iloc[-1], dtype=torch.float32)  # 标签
        return features, label

# 实例化数据集和 DataLoader
dataset = CSVDataset("runoob_pytorch_data.csv")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 遍历 DataLoader
for features, label in dataloader:
    print("特征:", features)
    print("标签:", label)
    break

输出结果:

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

数据杂坛

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值