从csv(excel)中读取数据集,自定义重写torch中的Dataset类,并使用DataLoader读取数据

该文章介绍如何在PyTorch中重写Dataset类来处理CSV数据,模拟数据缺失,并使用Dataloader进行批量读取。通过定义__init__,__len__和__getitem__方法,自定义的数据集可以从CSV文件加载数据,并应用掩模矩阵来模拟数据缺失。之后,利用Dataloader进行训练数据的遍历。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.从csv中读取数据集并重写Dataset类,模拟数据缺失

重写Dataset类需要重写__init__,__len__,__getitem__三个方法

__init__:参数初始化、路径定义等

__len__:返回数据集长度

__getitem__:根据索引获得对应的数据集,可同时返回data和label

代码

import os
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import random
import copy

class myDataSet(Dataset):
    def __init__(self, data_dir, mask, transform=None):
        """
        :param data_dir: 数据文件路径
        :param label_dir: 标签文件路径
        :param transform: transform操作
        """
        self.data_dir = data_dir
        self.transform = transform
        self.mask = mask # 掩膜矩阵,模拟数据缺失

        # 从csv文件中读取数据集
        data = pd.read_csv(self.data_dir, header=None)

        # 从csv文件中得到100*360组数据,每组数据为16个节点的12*5个特征
        self.data_tensor = torch.tensor(data.values) # (36000, 16*12*15)

        # 转化维度为样本数*节点*特征数
        self.data_tensor = self.data_tensor.view(100*360, 16, 12*5)

        # 拷贝label(完整数据)
        self.label = copy.copy(self.data_tensor)

        # 获得data(不完整数据)
        self.data_tensor = torch.mul(self.data_tensor, self.mask)

    def __len__(self):
        # 返回数据集长度
        return self.data_tensor.shape[0]

    def __getitem__(self, index):
        # 读取数据
        data = self.data_tensor[index,:,:]
        # 读取标签
        label = self.label[index,:,:]

        if self.transform:
            data = self.transform(data)
            label = self.transform(label)



        return data, label  # 返回数据和标签

2.使用Dataloader读取数据集

代码

import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from mydataset import myDataSet # 前面代码中重写的mydataset类

# 读取掩膜矩阵
mask_dir = r'mask.csv'
mask = pd.read_csv(mask_dir, header=None)
mask_tensor = torch.tensor(mask.values)



# 读取数据集
data_dir = r"obs_csv.csv"
train_dataset = myDataSet(
    data_dir=data_dir,
    mask=mask_tensor
)

# 加载数据集
train_iter = DataLoader(train_dataset, batch_size=36, shuffle=False)

for batchsz, (data, label) in enumerate(train_iter):
    # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
    print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, data.shape, label.shape))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值