重学PyTorch,粗略笔记(二)dataset,dataloader

dataset

对于单个样本

dataloader

批量样本

Dataset 存储样本和它们相应的标签,DataLoader 在 Dataset 基础上添加了一个迭代器,迭代器可以迭代数据集,以便能够轻松地访问 Dataset 中的样本(变为mini-batch形式,多个样本组合成mini-batch,random,保存在gpu中等)

支持下标索引获取样本,主要是拿出mini-batch(一组数据,训练时使用)

batch可以利用并行计算(向量计算),随机梯度下降每次一个样本时间过长(虽然随机性较好)

均衡性能和时间:mini-batch epoch:一个epoch中所有样本都参与了训练 batch-size:每次训练时用到的样本数量 iteration:batch的个数

如果可以使用下标获取dataset样本和知道dataset长度,则DataLoader可以自动生成mini-batch数据集

pytorch还提供部分预加载数据集
torch.utils.data.Dataset
https://pytorch.org/text/stable/datasets.html
在这里插入图片描述

在这里插入图片描述

构造数据集(init()函数中)的两种选择

(1)在init中加载整个数据集,用getitem时将第[i]个样本传出去

(2)数据集较大的情况:比如图片:可能只是加载图片的路径列表,图像分割时可能输出y也很大,则也使用文件名

防止显存超出

使用torchvision导入预加载的数据集

在这里插入图片描述

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

自定义数据集加载

自定义的数据集类必须实现三个函数: init, len, 和 getitem
getitem 从数据集中给定的索引 idx 处加载并返回一个样本
函数 len 返回我们数据集中的样本数。

# import gzip
# import numpy as np
# import os
# import requests

# # 下载 Fashion MNIST 数据集
# def download_fashion_mnist(base_url, filename, destination):
#     if not os.path.exists(destination):
#         os.makedirs(destination)
#     filepath = os.path.join(destination, filename)
#     if not os.path.exists(filepath):
#         url = base_url + filename
#         response = requests.get(url)
#         with open(filepath, 'wb') as f:
#             f.write(response.content)
#     return filepath

# # 解压 gz 文件
# def extract_gz(filepath, destination):
#     with gzip.open(filepath, 'rb') as f_in:
#         with open(destination, 'wb') as f_out:
#             f_out.write(f_in.read())
#     return destination

# # 读取二进制文件并转换为 numpy 数组
# def load_fashion_mnist_images(filepath):
#     with open(filepath, 'rb') as f:
#         data = f.read()
#         images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(-1, 28, 28)
#     return images

# def load_fashion_mnist_labels(filepath):
#     with open(filepath, 'rb') as f:
#         data = f.read()
#         labels = np.frombuffer(data, dtype=np.uint8, offset=8)
#     return labels

# base_url = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
# destination_dir = 'fashion_mnist'

# train_images_path = download_fashion_mnist(base_url, 'train-images-idx3-ubyte.gz', destination_dir)
# train_labels_path = download_fashion_mnist(base_url, 'train-labels-idx1-ubyte.gz', destination_dir)
# test_images_path = download_fashion_mnist(base_url, 't10k-images-idx3-ubyte.gz', destination_dir)
# test_labels_path = download_fashion_mnist(base_url, 't10k-labels-idx1-ubyte.gz', destination_dir)

# train_images = load_fashion_mnist_images(extract_gz(train_images_path, 'train-images-idx3-ubyte'))
# train_labels = load_fashion_mnist_labels(extract_gz(train_labels_path, 'train-labels-idx1-ubyte'))
# test_images = load_fashion_mnist_images(extract_gz(test_images_path, 't10k-images-idx3-ubyte'))
# test_labels = load_fashion_mnist_labels(extract_gz(test_labels_path, 't10k-labels-idx1-ubyte'))

# print(f'Train images shape: {train_images.shape}')
# print(f'Train labels shape: {train_labels.shape}')
# print(f'Test images shape: {test_images.shape}')
# print(f'Test labels shape: {test_labels.shape}')
# print(train_labels)


import gzip
import numpy as np
import os
import requests

# 下载 Fashion MNIST 数据集
def download_fashion_mnist(base_url, filename, destination):
    if not os.path.exists(destination):
        os.makedirs(destination)
    filepath = os.path.join(destination, filename)
    if not os.path.exists(filepath):
        url = base_url + filename
        response = requests.get(url)
        with open(filepath, 'wb') as f:
            f.write(response.content)
    return filepath

# 解压 gz 文件
def extract_gz(filepath, destination):
    with gzip.open(filepath, 'rb') as f_in:
        with open(destination, 'wb') as f_out:
            f_out.write(f_in.read())
    return destination

# 读取二进制文件并转换为 numpy 数组
def load_fashion_mnist_images(filepath):
    with open(filepath, 'rb') as f:
        data = f.read()
        images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(-1, 28, 28)
    return images

def load_fashion_mnist_labels(filepath):
    with open(filepath, 'rb') as f:
        data = f.read()
        labels = np.frombuffer(data, dtype=np.uint8, offset=8)
    return labels

# 保存 numpy 数组到文件
def save_numpy_array(data, filepath):
    np.save(filepath, data)

# 加载保存的 numpy 数组
def 
### 回答1: DatasetDataLoaderPyTorch 中用于加载和处理数据的两个主要组件。Dataset 用于从数据源中提取和加载数据,DataLoader 则用于将数据转换为适合机器学习模型训练的格式。 ### 回答2: 在PyTorch中,DatasetDataLoader是用于处理和加载数据的两个重要类。 Dataset是一个抽象类,用于表示数据集对象。我们可以自定义Dataset子类来处理我们自己的数据集。通过继承Dataset类,我们需要实现两个主要方法: - __len__()方法:返回数据集的大小(样本数量) - __getitem__(idx)方法:返回索引为idx的样本数据 使用Dataset类的好处是可以统一处理训练集、验证集和测试集等不同的数据集,将数据进行一致的格式化和预处理。 DataLoader是一个实用工具,用于将Dataset对象加载成批量数据。数据加载器可以根据指定的批大小、是否混洗样本和多线程加载等选项来提供高效的数据加载方式。DataLoader是一个可迭代对象,每次迭代返回一个批次的数据。我们可以通过循环遍历DataLoader对象来获取数据。 使用DataLoader可以实现以下功能: - 数据批处理:将数据集划分为批次,并且可以指定每个批次的大小。 - 数据混洗:可以通过设置shuffle选项来随机打乱数据集,以便更好地训练模型。 - 并行加载:可以通过设置num_workers选项来指定使用多少个子进程来加载数据,加速数据加载过程。 综上所述,DatasetDataLoaderPyTorch中用于处理和加载数据的两个重要类。Dataset用于表示数据集对象,我们可以自定义Dataset子类来处理我们自己的数据集。而DataLoader是一个实用工具,用于将Dataset对象加载成批量数据,提供高效的数据加载方式,支持数据批处理、数据混洗和并行加载等功能。 ### 回答3: 在pytorch中,Dataset是一个用来表示数据的抽象类,它封装了数据集的访问方式和数据的获取方法。Dataset类提供了读取、处理和转换数据的功能,可以灵活地处理各种类型的数据集,包括图像、语音、文本等。用户可以继承Dataset类并实现自己的数据集类,根据实际需求定制数据集。 Dataloader是一个用来加载数据的迭代器,它通过Dataset对象来获取数据,并按照指定的batch size进行分批处理。Dataloader可以实现多线程并行加载数据,提高数据读取效率。在训练模型时,通常将Dataset对象传入Dataloader进行数据加载,并通过循环遍历Dataloader来获取每个batch的数据进行训练。 DatasetDataloader通常配合使用,Dataset用于数据的读取和预处理,Dataloader用于并行加载和分批处理数据。使用DatasetDataloader的好处是可以轻松地处理大规模数据集,实现高效的数据加载和预处理。此外,DatasetDataloader还提供了数据打乱、重复采样、数据划分等功能,可以灵活地控制数据的访问和使用。 总之,DatasetDataloaderpytorch中重要的数据处理模块,它们提供了方便的接口和功能,用于加载、处理和管理数据集,为模型训练和评估提供了便利。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值