疏锦行Python打卡 DAY 38 Dataset和Dataloader类

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

import matplotlib.pyplot as plt

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

# 示例代码
class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __getitem__(self, idx):
        return self.data[idx]

# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __len__(self):
        return len(self.data)

# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

# minist数据集的简化版本
class MNIST(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签
        self.transform = transform # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx): # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None: # 如果有预处理操作
            img = self.transform(img) # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
            
        return img, target  # 返回处理后的图像和标签

# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
    plt.show()

print(f"Label: {label}")
imshow(image)

# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

打卡:@浙大疏锦行

BP神经网络是一种经典的人工神经网络模型,用于解决分回归问题。在Python中,我们可以使用一些库来构建训练BP神经网络,例如PyTorch、Keras或TensorFlow。 对于BP神经网络,我们首先需要准备一个数据集。数据集应该包含输入特征对应的标签。特征可以是数字、文本或图像等形式,而标签则表示我们要预测的目标。通常,我们将数据集分为训练集测试集,用于训练评估模型的性能。 在Python中,我们可以使用PyTorch库来处理数据集加载器。PyTorch提供了一个名为`torch.utils.data.Dataset`的,用于定义自定义数据集。我们可以继承这个并实现`__len__``__getitem__`方法来获取数据集的长度索引对应的样本。 下一步是使用`torch.utils.data.DataLoader`来创建数据加载器。数据加载器可以帮助我们以批量的方式加载数据,并提供多线程处理数据打乱等功能。我们可以设置批量大小、是否打乱数据以及使用多个线程来加载数据。 以下是一个示例代码,展示了如何准备数据集数据加载器: ```python import torch from torch.utils.data import Dataset, DataLoader # 定义自定义数据集 class MyDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __len__(self): return len(self.data) def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y # 准备数据 data = [...] # 特征数据 targets = [...] # 标签数据 # 创建数据集 dataset = MyDataset(data, targets) # 创建数据加载器 batch_size = 32 shuffle = True num_workers = 4 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) ``` 在上面的示例中,`MyDataset`是一个自定义数据集,其中的`__len__`方法返回数据集的长度,`__getitem__`方法根据索引返回对应的特征标签。然后,我们将数据集传递给`DataLoader`来创建数据加载器,并设置了批量大小为32,打乱数据并使用4个线程进行加载。 通过使用数据集数据加载器,我们可以方便地准备加载数据,以供BP神经网络进行训练评估。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值