从代码角度深入浅出图神经网络系列笔记(三)

本文详细介绍了PyTorch Geometric中Dataset类的使用,包括如何构建分批加载的大数据集,对比InMemoryDataset的区别,以及Mini-Batching的实现方式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

这一节笔记中主要针对继承Dataset,分次加载到内存,这种数据集一般很大,不适合一次性加载完毕,需要分批加载处理。

构建数据集

1、Dataset

pytorch geometric 构建数据集分两种:
1、继承InMemoryDataset,一次性加载所有的数据到内存
2、继承Dataset,分次加载到内存

Mini-Batching:将一组样本组合成一个统一的表示形式,进行并行处理

2、官方文档例子

在这里插入图片描述
首先还是看下引入的库文件,对比一下InMemoryDataset,这里我们引入的是Dataset,对比一下这两个库,初始化的参数完全一致
在这里插入图片描述
主要是Dataset多了len()get()

  • torch_geometric.data.Dataset.len(): 返回数据集中的样本数

  • torch_geometric.data.Dataset.get(): 实现加载单个图的逻辑

下面来看对比分析:

import os.path as osp # 调用系统路径

import torch
from torch_geometric.data import Dataset


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)# 对比InMemoryDataset 这块少了直接加载的代码,因为我们需要分次加载

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self): # 对比一次性加载,这里会有多个
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self): # 返回数目
        return len(self.processed_file_names)

    def get(self, idx): # 一个一个文件手动加载到内存中
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

3、process解读

下面是我从某博客中引用的一段话,个人觉得说的挺好的。process()方法存在的意义是:

  1. 原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件。
  2. 这样在get()方法中,就可以直接使用torch.load()函数,读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据。
  3. 而不用在get()方法里面做数据转换操作,比如说,把其他格式的数据转换为 torch_geometric.data.Data类型的数据。
  4. 当然我们也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。

Ps:上面这段话部分针对的是Dataset,其实InMemoryDataset也差不多,只不过最后不需要逐个加载进内存,而是直接加载进内存

MINI-BATCHING

官方文档地址:

https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html#pairs-of-graphs

我觉得这里我没有理解,视频的解释实在是太少了,之后再补

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值