PyTorch Geometric Temporal:时空图神经网络入门指南
项目概述
PyTorch Geometric Temporal 是一个基于 PyTorch Geometric 的时空图神经网络扩展库,专注于处理时空信号数据。作为首个针对几何结构的时序深度学习开源库,它提供了在动态和静态图上实现恒定时间差图神经网络的能力。
核心概念
时空图数据特点
时空图数据同时包含空间维度和时间维度的信息:
- 空间维度:通过图结构表示实体间的关系
- 时间维度:数据随时间演变形成序列
数据迭代器类型
PyTorch Geometric Temporal 提供了三类数据迭代器,满足不同场景需求:
-
静态图时序信号(StaticGraphTemporalSignal)
- 图结构保持不变
- 节点/边特征随时间变化
- 典型应用:交通流量预测
-
动态图时序信号(DynamicGraphTemporalSignal)
- 图结构和节点/边特征都随时间变化
- 典型应用:社交网络演化分析
-
动态图静态信号(DynamicGraphStaticSignal)
- 图结构随时间变化
- 节点/边特征保持不变
- 典型应用:固定传感器网络监测
数据处理流程
数据快照结构
每个时间点的数据快照包含以下关键属性:
edge_index
:边索引张量,定义图结构edge_attr
:边特征张量,用于加权聚合x
:节点特征矩阵y
:预测目标值
批处理支持
对于需要同时处理多个图的情况,库提供了批处理迭代器版本,使用块对角矩阵技术实现高效批处理。
基准数据集
项目包含多个真实世界的基准数据集,方便模型评估:
新发布数据集
- 匈牙利水痘病例数据集:地区级流行病学数据
- PedalMe伦敦数据集:自行车共享数据
- 在线百科数学页面数据集:网页访问量数据
- 风力发电机输出数据集:能源生产数据
集成数据集
- Pems Bay交通数据集
- Metr LA交通数据集
- 英格兰疫情数据集
- Twitter网球数据集
典型应用案例
案例1:流行病预测
使用匈牙利水痘数据集预测地区水痘病例数:
# 数据准备
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
# 模型定义
class RecurrentGCN(torch.nn.Module):
def __init__(self, node_features):
super().__init__()
self.recurrent = DCRNN(node_features, 32, 1)
self.linear = torch.nn.Linear(32, 1)
def forward(self, x, edge_index, edge_weight):
h = self.recurrent(x, edge_index, edge_weight)
h = F.relu(h)
return self.linear(h)
案例2:网页流量预测
使用在线百科数学页面数据集预测每日访问量:
# 数据准备
loader = WikiMathsDatasetLoader()
dataset = loader.get_dataset(lags=14)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)
# 模型定义
class RecurrentGCN(torch.nn.Module):
def __init__(self, node_features, filters):
super().__init__()
self.recurrent = GConvGRU(node_features, filters, 2)
self.linear = torch.nn.Linear(filters, 1)
def forward(self, x, edge_index, edge_weight):
h = self.recurrent(x, edge_index, edge_weight)
h = F.relu(h)
return self.linear(h)
模型训练技巧
- 损失计算:可以累积多个时间步的损失后反向传播,也可以每个时间步单独反向传播
- 优化器选择:推荐使用Adam优化器,学习率通常设为0.01
- 正则化:可通过Dropout层或权重衰减防止过拟合
- 评估指标:回归任务常用MSE,分类任务可用准确率等
总结
PyTorch Geometric Temporal为时空图数据分析提供了强大工具,其主要优势包括:
- 统一的API设计,简化了时空图数据处理流程
- 丰富的预实现模型,涵盖多种时空图神经网络架构
- 多样的基准数据集,便于模型评估比较
- 与PyTorch生态无缝集成,支持GPU加速
对于需要处理时空图数据的任务,如交通预测、流行病建模、社交网络分析等,该库都是值得考虑的选择。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考