Dataset: handles individual samples
DataLoader: handles batches of samples
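Dataset stores the samples and their corresponding labels. DataLoader wraps an iterable around the Dataset so that its samples are easy to access: it combines several samples into a mini-batch, can shuffle them randomly, can place batches in pinned memory for faster copying to the GPU, and so on.
A Dataset supports indexing to fetch a single sample. The main goal is to draw mini-batches (groups of samples used during training).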
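Batches make use of parallel (vectorized) computation. Stochastic gradient descent on one sample at a time takes too long, even though its randomness is better.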
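To illustrate the vectorization point, here is a minimal sketch using a hypothetical linear-regression model with squared loss (not part of the original notes). The gradient for a whole mini-batch is a single matrix product, and it equals the average of the per-sample gradients computed one at a time in a Python loop.

```python
import numpy as np

def per_sample_grad(w, x, y):
    # Gradient of 0.5 * (x @ w - y)**2 with respect to w, for one sample
    return (x @ w - y) * x

def batch_grad_loop(w, X, Y):
    # One sample at a time (SGD-style), then averaged over the batch
    return np.mean([per_sample_grad(w, x, y) for x, y in zip(X, Y)], axis=0)

def batch_grad_vectorized(w, X, Y):
    # The whole mini-batch at once, using matrix products
    residual = X @ w - Y              # shape (batch,)
    return X.T @ residual / len(X)    # shape (features,)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))   # a mini-batch of 32 samples with 5 features each
Y = rng.normal(size=32)
w = rng.normal(size=5)

g_loop = batch_grad_loop(w, X, Y)
g_vec = batch_grad_vectorized(w, X, Y)
print(np.allclose(g_loop, g_vec))  # True
```

Both functions compute the same quantity. The vectorized version hands the work to optimized linear-algebra routines instead of a Python loop, which is where the speed of batching comes from.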
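Mini-batches balance performance against training time. Terminology:
epoch: one full pass in which every sample in the dataset takes part in training
batch size: the number of samples used in each training step
iteration: one training step on one batch; the number of iterations per epoch equals the number of batches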
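The relation between these terms comes down to simple arithmetic. The sketch below uses a hypothetical helper `iterations_per_epoch`. The `drop_last` flag mirrors the DataLoader parameter of the same name, which discards an incomplete final batch.

```python
import math

def iterations_per_epoch(num_samples, batch_size, drop_last=False):
    # Number of mini-batches (iterations) needed for one epoch
    if drop_last:
        # The incomplete final batch is discarded
        return num_samples // batch_size
    # Otherwise the last batch may be smaller than batch_size
    return math.ceil(num_samples / batch_size)

# e.g. the 60,000-sample FashionMNIST training split with a batch size of 64
print(iterations_per_epoch(60000, 64))                  # 938
print(iterations_per_epoch(60000, 64, drop_last=True))  # 937
```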
If samples can be fetched from the Dataset by index and the Dataset's length is known, DataLoader can generate the mini-batches automatically.
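A minimal sketch of this point, using a hypothetical toy dataset `SquaresDataset` (sample `i` is the pair `(i, i**2)`). It only defines `__getitem__` and `__len__`. DataLoader then handles shuffling and batching, and its default collate function stacks the samples into batch tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    # Hypothetical toy dataset: sample i is (x=[i], y=i**2)
    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The default sampler uses this to know the valid index range
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx ** 2))
        return x, y

loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=True)
for xb, yb in loader:
    # Default collate stacks samples: xb has shape (batch, 1), yb has shape (batch,)
    print(xb.shape, yb.shape)
```

With 10 samples and `batch_size=4`, an epoch yields batches of 4, 4 and 2 samples.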
PyTorch also provides a number of pre-loaded datasets, which subclass torch.utils.data.Dataset. See, for example, the torchtext dataset list:
https://pytorch.org/text/stable/datasets.html
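Two options for building the dataset (in __init__()):
(1) Load the entire dataset in __init__, and have __getitem__ return the i-th sample.
(2) For large datasets, such as images, load only a list of image file paths and read each file in __getitem__. In image segmentation the target y (the mask) can also be large, so store its file names as well.
Option (2) keeps the dataset from exhausting memory.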
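Option (2) can be sketched as follows. `LazyNpyDataset` is a hypothetical class. It stores `.npy` files as stand-ins for real image and mask files, which in practice would be decoded with an image library. `__init__` only keeps the paths, and each sample is read from disk inside `__getitem__`.

```python
import os
import tempfile
import numpy as np
import torch
from torch.utils.data import Dataset

class LazyNpyDataset(Dataset):
    # Hypothetical example: only file paths are held in memory;
    # each (image, mask) pair is loaded from disk on demand.
    def __init__(self, image_paths, mask_paths):
        assert len(image_paths) == len(mask_paths)
        self.image_paths = list(image_paths)
        self.mask_paths = list(mask_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = torch.from_numpy(np.load(self.image_paths[idx]))
        mask = torch.from_numpy(np.load(self.mask_paths[idx]))
        return image, mask

# Usage: write a few fake image/mask files to a temporary directory
tmp_dir = tempfile.mkdtemp()
image_paths, mask_paths = [], []
for i in range(3):
    img_path = os.path.join(tmp_dir, f"img_{i}.npy")
    mask_path = os.path.join(tmp_dir, f"mask_{i}.npy")
    np.save(img_path, np.full((3, 8, 8), i, dtype=np.float32))
    np.save(mask_path, np.full((8, 8), i, dtype=np.int64))
    image_paths.append(img_path)
    mask_paths.append(mask_path)

ds = LazyNpyDataset(image_paths, mask_paths)
image, mask = ds[2]
print(len(ds), image.shape, mask.shape)
```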
Loading a pre-loaded dataset with torchvision
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
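Loading a custom dataset
A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
__getitem__ loads and returns the sample at the given index idx.
__len__ returns the number of samples in the dataset.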
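A minimal sketch of such a class, assuming the data is already in memory as uint8 numpy arrays of shape `(N, 28, 28)` for images and `(N,)` for labels, like the arrays produced by the IDX loader functions that follow. The class name `ArrayFashionMNIST` and the scaling choice are hypothetical. The scaling mimics what `ToTensor` does for a grayscale image.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ArrayFashionMNIST(Dataset):
    # Hypothetical custom dataset over in-memory arrays
    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.labels)

    def __getitem__(self, idx):
        # Copy the idx-th image into a float tensor, scale to [0, 1],
        # and add a channel dimension -> shape (1, 28, 28)
        image = torch.tensor(self.images[idx], dtype=torch.float32).div(255).unsqueeze(0)
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Usage with random stand-in data instead of the real FashionMNIST arrays
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(10, 28, 28)).astype(np.uint8)
labels = rng.integers(0, 10, size=10).astype(np.uint8)

ds = ArrayFashionMNIST(images, labels)
image, label = ds[0]
xb, yb = next(iter(DataLoader(ds, batch_size=4)))
print(len(ds), image.shape, xb.shape, yb.shape)
```

`torch.tensor` copies the data. That matters because arrays created with `np.frombuffer` are read-only, and `torch.from_numpy` warns when given a non-writable array.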
# import gzip
# import numpy as np
# import os
# import requests
# # Download the Fashion MNIST dataset
# def download_fashion_mnist(base_url, filename, destination):
#     if not os.path.exists(destination):
#         os.makedirs(destination)
#     filepath = os.path.join(destination, filename)
#     if not os.path.exists(filepath):
#         url = base_url + filename
#         response = requests.get(url)
#         with open(filepath, 'wb') as f:
#             f.write(response.content)
#     return filepath
# # Decompress a .gz file
# def extract_gz(filepath, destination):
#     with gzip.open(filepath, 'rb') as f_in:
#         with open(destination, 'wb') as f_out:
#             f_out.write(f_in.read())
#     return destination
# # Read a binary file and convert it to a numpy array
# def load_fashion_mnist_images(filepath):
#     with open(filepath, 'rb') as f:
#         data = f.read()
#     images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(-1, 28, 28)
#     return images
# def load_fashion_mnist_labels(filepath):
#     with open(filepath, 'rb') as f:
#         data = f.read()
#     labels = np.frombuffer(data, dtype=np.uint8, offset=8)
#     return labels
# base_url = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
# destination_dir = 'fashion_mnist'
# train_images_path = download_fashion_mnist(base_url, 'train-images-idx3-ubyte.gz', destination_dir)
# train_labels_path = download_fashion_mnist(base_url, 'train-labels-idx1-ubyte.gz', destination_dir)
# test_images_path = download_fashion_mnist(base_url, 't10k-images-idx3-ubyte.gz', destination_dir)
# test_labels_path = download_fashion_mnist(base_url, 't10k-labels-idx1-ubyte.gz', destination_dir)
# train_images = load_fashion_mnist_images(extract_gz(train_images_path, 'train-images-idx3-ubyte'))
# train_labels = load_fashion_mnist_labels(extract_gz(train_labels_path, 'train-labels-idx1-ubyte'))
# test_images = load_fashion_mnist_images(extract_gz(test_images_path, 't10k-images-idx3-ubyte'))
# test_labels = load_fashion_mnist_labels(extract_gz(test_labels_path, 't10k-labels-idx1-ubyte'))
# print(f'Train images shape: {train_images.shape}')
# print(f'Train labels shape: {train_labels.shape}')
# print(f'Test images shape: {test_images.shape}')
# print(f'Test labels shape: {test_labels.shape}')
# print(train_labels)
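The `offset=16` and `offset=8` arguments in the loaders skip the IDX file headers. An image file starts with four 4-byte big-endian integers: the magic number 2051, the image count, the number of rows and the number of columns. A label file starts with two: the magic number 2049 and the label count. The sketch below reads those headers with `struct` instead of assuming fixed shapes, and checks the magic numbers. The helper names are hypothetical.

```python
import struct
import numpy as np

def parse_idx_images(raw):
    # Header: magic, count, rows, cols as 4-byte big-endian unsigned ints (16 bytes)
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:
        raise ValueError(f"not an IDX image file (magic={magic})")
    return np.frombuffer(raw, dtype=np.uint8, offset=16).reshape(count, rows, cols)

def parse_idx_labels(raw):
    # Header: magic, count as 4-byte big-endian unsigned ints (8 bytes)
    magic, count = struct.unpack(">II", raw[:8])
    if magic != 2049:
        raise ValueError(f"not an IDX label file (magic={magic})")
    labels = np.frombuffer(raw, dtype=np.uint8, offset=8)
    if labels.shape[0] != count:
        raise ValueError("label count does not match header")
    return labels

# Usage with synthetic bytes: 2 images of 3x4 pixels and 2 labels
img_bytes = struct.pack(">IIII", 2051, 2, 3, 4) + bytes(range(24))
lbl_bytes = struct.pack(">II", 2049, 2) + bytes([7, 1])
images = parse_idx_images(img_bytes)
labels = parse_idx_labels(lbl_bytes)
print(images.shape, labels.tolist())  # (2, 3, 4) [7, 1]
```

With the real files, `raw` would be the decompressed bytes, e.g. `gzip.open(path, 'rb').read()`. This also avoids writing an extracted copy to disk.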
import gzip
import numpy as np
import os
import requests
# Download the Fashion MNIST dataset
def download_fashion_mnist(base_url, filename, destination):
    os.makedirs(destination, exist_ok=True)
    filepath = os.path.join(destination, filename)
    if not os.path.exists(filepath):
        url = base_url + filename
        response = requests.get(url)
        response.raise_for_status()  # don't save an HTTP error page as the .gz file
        with open(filepath, 'wb') as f:
            f.write(response.content)
    return filepath
# Decompress a .gz file
def extract_gz(filepath, destination):
    with gzip.open(filepath, 'rb') as f_in:
        with open(destination, 'wb') as f_out:
            f_out.write(f_in.read())
    return destination
# Read a binary IDX file and convert it to a numpy array
def load_fashion_mnist_images(filepath):
    with open(filepath, 'rb') as f:
        data = f.read()
    # Skip the 16-byte header (magic number, image count, rows, cols)
    images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(-1, 28, 28)
    return images
def load_fashion_mnist_labels(filepath):
    with open(filepath, 'rb') as f:
        data = f.read()
    # Skip the 8-byte header (magic number, label count)
    labels = np.frombuffer(data, dtype=np.uint8, offset=8)
    return labels
# Save a numpy array to a file
def save_numpy_array(data, filepath):
    np.save(filepath, data)
# Load a saved numpy array
def