torch.utils.data.DataLoader与enumerate联用发现

文章探讨了在PyTorch中使用torch.utils.data.DataLoader对数据集进行批处理,并结合enumerate遍历训练数据时,发现在不同的epoch中,即使在同一step,取出的训练数据也会变化。这表明DataLoader在联用enumerate时,取样过程是随机的,尤其是在设置了shuffle=True的情况下。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、发现说明

        使用torch.utils.data.DataLoader()函数对数据集进行按批分割处理,然后在训练网络时用enumerate()函数取出训练数据。发现不同Epoch,相同step(下文解释)情况下,enumerate取到的训练数据不一样。这样意味着DataLoader()联用enumerate()使用时候,每次取到的训练数据是动态的。

二、代码演示

import torch
import torch.utils.data as Data
import argparse
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms, datasets
import torch.backends.cudnn as cudnn

torch.set_printoptions(threshold=np.inf)   #tensor输出不省略

device = 'cuda' if torch.cuda.is_available() else 'cpu'

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.00005, type=float, help='learning rate')
parser.add_argument('--dataset', default='CIFAR10', type=str, choices=['MNIST', 'FashionMNIST', 'CIFAR10'])
parser.add_argument('--batchsize', default=128, type=int)
parser.add_argument('--epoch', default=2, type=int)
args = parser.parse_args()
batch_size = args.batchsize


#定义数据loader函数
def get_data_loader():
    if(args.dataset=='MNIST'):
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3081])
        ])
        trainset = datasets.MNIST(root='.', train=True, transform=data_transform,
                                         download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.MNIST(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    if (args.dataset == 'FashionMNIST'):
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3081])
        ])
        trainset = datasets.FashionMNIST(root='.', train=True, transform=data_transform,
                                  download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.FashionMNIST(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    if (args.dataset == 'CIFAR10'):
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        trainset = datasets.CIFAR10(root='.', train=True, transform=data_transform,
                                  download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.CIFAR10(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    return train_loader, test_loader


if __name__ == '__main__':
    Train_loader, Test_loader = get_data_loader()
    for e in range(args.epoch):
        print("Epoch: ", e + 1)
        for step, (x, y) in enumerate(Train_loader, 0):
            if device == 'cuda':
                input, label = x.cuda(), y.cuda()
            print(f'step:{step}')
            print(f'样本训练集\t input:{input.shape},  label:{label.shape}')
            print(input[0][0][0], label[0])
            break       #用break跳出,只对比每个epoch相同step下,enumerate取到的训练数据
        print('\n')

我们先执行了数据装载代码,主观意识上,此时训练集中每个batch的顺序是静态的,

Train_loader, Test_loader = get_data_loader()

然后通过enumerate()去取个训练集中每个batch的数据,应该不同的Epoch在step相同时,取到的数据应该相同。

但是,从结果看,显然,第1,2个Epoch在step0相同时,取到的训练数据是不同的,也就意味着DataLoader()联用enumerate()的时候,取样本的过程是动态的。

<think>我们正在处理一个关于PyTorch DataLoader对象的问题。用户提供了一个内存地址,但通常我们不会直接使用内存地址来获取对象信息。在Python中,我们可以通过变量名来访问对象。如果用户是在调试环境中(如使用pdb或ipdb),他们可能想检查这个内存地址对应的对象。然而,在正常的PyTorch使用中,我们通常不会直接操作内存地址。因此,我们需要引导用户提供更多上下文,或者解释如何检查DataLoader对象。由于用户提到了内存地址,我们可以假设用户可能在调试器中看到了这个对象,并希望了解它的详细信息。在调试器中,我们可以使用对象的内存地址来打印其信息,但更常见的是通过变量名。我们可以尝试以下方法:1.如果用户有一个变量指向该DataLoader,那么直接打印该变量即可。2.如果没有变量,但在调试器中,可以通过内存地址访问对象(这需要特定的调试器支持,并且通常不推荐,因为内存地址可能变化)。实际上,在Python中,我们可以使用`gc`模块来通过内存地址获取对象,但这不是常规做法,而且可能不安全。因此,我们建议用户通过变量名来访问DataLoader对象,然后我们可以展示如何获取DataLoader的详细信息。下面,我们将介绍如何检查PyTorchDataLoader对象,包括其属性、状态等。注意:由于我们无法直接访问用户的内存地址,我们将提供通用的方法。</think>### PyTorch DataLoader 对象详解 当您看到类似 `PyTorch DataLoader object details at memory address 0x00000176B04F9A88` 的信息时,这表示您正在处理一个内存中的 DataLoader 实例。以下是关键属性和调试方法: #### 核心属性解析 1. **数据集 (dataset)** DataLoader 的核心数据源,可通过 `loader.dataset` 访问: ```python print(data_loader.dataset) # 显示数据集类型和大小 ``` 常见类型:`TensorDataset`, `ImageFolder` 等 2. **批处理参数** - `batch_size`:每批样本数量,例如 $bs=32$ - `shuffle`:是否打乱数据顺序(布尔值) - `num_workers`:数据加载子进程数 3. **采样器 (sampler)** 控制数据索引的采样逻辑: ```python print(data_loader.sampler) # 如 RandomSampler, SequentialSampler ``` 4. **批次格式** 通过迭代获取批次结构: ```python batch = next(iter(data_loader)) print(type(batch)) # 通常为元组 (inputs, labels) ``` #### 调试方法 1. **检查配置参数**: ```python print(f"Batch size: {data_loader.batch_size}") print(f"Workers: {data_loader.num_workers}") print(f"Shuffle: {data_loader.shuffle}") ``` 2. **验证数据流**: ```python for i, (inputs, targets) in enumerate(data_loader): print(f"Batch {i}: inputs shape={inputs.shape}, targets={targets[:5]}") if i > 2: break # 仅检查前3个批次 ``` 3. **内存地址诊断**(在 IPython/Jupyter 中): ```python import torch %whos # 列出所有变量及内存地址 ``` #### 常见问题排查 - **内存地址变动**:Python 对象内存地址在每次运行时变化,这是正常现象 - **DataLoader 未初始化**:确保已执行 `iter(data_loader)` - **多进程问题**:当 $num\\_workers > 0$ 时,调试需使用 `if __name__ == '__main__':` 包裹代码 > 最佳实践:始终通过变量名而非内存地址操作对象[^1]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值