学习Pytorch+Python之自定义数据集(创建CSV文件)

1、CSV文件
逗号分隔值(Comma-Separated Values,CSV,有时也称为字符分隔值,因为分隔字符也可以不是逗号),其文件以纯文本形式存储表格数据(数字和文本)。文件格式如下图,也可用Excel打开。
在这里插入图片描述在这里插入图片描述
2、简单分类数据集
本文数据集是一个蚂蚁(ants)和蜜蜂(bees)的简单的二分类数据,分别存放于dataset文件夹下的ants_image和bees_image文件夹下,用文件夹名字作为标签。
3、创建CSV文件代码

# 导入相关库
import csv
import random
import os
import glob
# 数据地址,注意是分类文件夹的上层文件夹
# dataset文件夹中的两个子文件夹分别对应两个类名
root = 'dataset'
# 对数据地址中的文件夹进行遍历,将类名存放于列表names中
names = os.listdir(root)
# 创建名为images的空列表用于存放图像地址
images = []
# 创建名称、标签字典,用于存放二分类的类名和标签
names_labels 
<think>在PyTorch创建自定义数据集类需要继承`torch.utils.data.Dataset`并实现两个关键方法:`__len__`和`__getitem__`。 1. `__len__`方法应返回数据集的大小(样本数)。 2. `__getitem__`方法根据索引返回一个样本(数据和标签)。 以下是一个创建自定义数据集类的步骤和示例代码: 步骤: a. 导入必要的模块:`from torch.utils.data import Dataset` b. 定义类并继承`Dataset` c. 在`__init__`方法中初始化数据(如文件路径、数据列表等) d. 实现`__len__`方法 e. 实现`__getitem__`方法,返回单个样本(例如,一个图像和其标签) 示例场景:假设我们有一个包含图像路径和标签CSV文件,我们将创建一个数据集类来加载这些图像和标签。 代码示例:</think>### PyTorch创建自定义数据集类的方法 在PyTorch创建自定义数据集类需要继承`torch.utils.data.Dataset`类并实现三个核心方法:`__init__`、`__len__`和`__getitem__`。以下是详细步骤和代码示例: #### 1. 基本实现方法 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data, labels, transform=None): """ 初始化数据集 :param data: 输入数据(如图像路径列表或特征张量) :param labels: 对应标签 :param transform: 数据预处理/增强函数 """ self.data = data self.labels = labels self.transform = transform def __len__(self): """返回数据集大小""" return len(self.data) def __getitem__(self, idx): """ 获取单个样本 :param idx: 样本索引 :return: (样本数据, 标签) """ sample = self.data[idx] label = self.labels[idx] # 应用数据转换(如果有) if self.transform: sample = self.transform(sample) return sample, label ``` #### 2. 完整示例:图像数据集 ```python import os import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms class ImageDataset(Dataset): def __init__(self, root_dir, transform=None): """ :param root_dir: 包含图像和label.txt的目录 :param transform: 数据预处理管道 """ self.root_dir = root_dir self.transform = transform self.image_paths = [] self.labels = [] # 从label.txt加载数据路径和标签 with open(os.path.join(root_dir, 'labels.txt'), 'r') as f: for line in f: img_path, label = line.strip().split(',') self.image_paths.append(os.path.join(root_dir, img_path)) self.labels.append(int(label)) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img_path = self.image_paths[idx] image = Image.open(img_path).convert('RGB') # 确保RGB格式 label = self.labels[idx] if self.transform: image = self.transform(image) return image, label # 使用示例 if __name__ == "__main__": # 定义数据预处理流程 transform = transforms.Compose([ transforms.Resize((256, 256)), # 调整大小 transforms.RandomHorizontalFlip(), # 数据增强:随机水平翻转 transforms.ToTensor(), # 转为Tensor transforms.Normalize( # 标准化 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 创建数据集实例 dataset = ImageDataset( root_dir='path/to/your/images', transform=transform ) # 创建DataLoader dataloader = torch.utils.data.DataLoader( dataset, batch_size=32, shuffle=True, num_workers=4 ) # 测试数据加载 for images, labels in dataloader: print(f"批处理图像形状: {images.shape}") print(f"批处理标签形状: {labels.shape}") break ``` #### 3. 关键实现要点解释 1. **`__init__`方法**: - 初始化数据路径、标签和转换函数 - 通常在此处加载元数据(如文件路径列表) - 避免在此处加载所有数据到内存(惰性加载) 2. **`__len__`方法**: - 必须返回数据集的总样本数 - DataLoader依赖此方法确定迭代次数 3. **`__getitem__`方法**: - 核心方法,根据索引返回单个样本 - 应包含数据加载和预处理逻辑 - 支持索引访问:`dataset[i]` 获取第i个样本 - 建议在此方法中实现惰性加载(需要时再读取数据) 4. **数据转换(transform)**: - 推荐使用`torchvision.transforms`模块 - 训练和测试时可使用不同的转换管道 - 包含数据标准化、增强等操作 #### 4. 高级技巧 **处理大型数据集**: ```python class LargeDataset(Dataset): def __init__(self, hdf5_path): import h5py self.file = h5py.File(hdf5_path, 'r') # 打开但不加载所有数据 self.data = self.file['images'] self.labels = self.file['labels'] def __len__(self): return len(self.data) def __getitem__(self, idx): # 按需从磁盘加载 return self.data[idx], self.labels[idx] def __del__(self): self.file.close() # 确保关闭文件 ``` **多模态数据**: ```python class MultiModalDataset(Dataset): def __getitem__(self, idx): # 返回多种类型数据 return { 'image': image_tensor, 'text': text_embedding, 'audio': audio_features, 'label': label } ```
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值