数据完全存于内存的数据集类 + 节点预测与边预测任务实践
PyG使用数据的一般过程
- 从网络上下载数据原始文件；
- 对数据原始文件做处理，为每一个图样本生成一个**`Data`对象**；
- 对每一个`Data`对象执行数据处理，使其转换成新的`Data`对象；
- 过滤`Data`对象；
- 保存`Data`对象到文件；
- 获取`Data`对象：在每一次获取`Data`对象时，都先对`Data`对象做数据变换（于是获取到的是数据变换后的`Data`对象）。
任务实践
节点预测任务
## 定义GAT
class GAT(torch.nn.Module):
    """Stack of GATConv+ReLU layers, then dropout and a linear output layer.

    Args:
        num_features: width of each input node feature vector.
        hidden_channels_list: output width of each GATConv layer, in order;
            one (GATConv, ReLU) pair is built per entry. Must be a list
            (it is concatenated with ``[num_features]``).
        num_classes: number of output logits per node.
    """

    def __init__(self, num_features, hidden_channels_list, num_classes):
        super().__init__()
        # Fixed seed so parameter initialisation is reproducible.
        # NOTE: this reseeds the global torch RNG every time a GAT is built.
        torch.manual_seed(12345)
        widths = [num_features] + hidden_channels_list
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            # The header string tells PyG's Sequential which inputs the
            # layer consumes and what it outputs.
            layers.append((GATConv(in_dim, out_dim), 'x, edge_index -> x'))
            layers.append(ReLU(inplace=True))
        self.convseq = Sequential('x, edge_index', layers)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` over graph ``edge_index``."""
        hidden = self.convseq(x, edge_index)
        # Dropout is only active in training mode (self.training).
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.linear(hidden)
## 获取数据集并进行分析
import os.path as osp
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
# Load Cora; NormalizeFeatures is applied to each sample on access.
dataset = Planetoid('dataset', 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]
# Node-level labels and masks are no longer needed for link prediction.
data.train_mask = data.val_mask = data.test_mask = data.y = None
print(data.edge_index.shape)
# torch.Size([2, 10556])

# NOTE(review): train_test_split_edges is deprecated in PyG 2.x in favour of
# torch_geometric.transforms.RandomLinkSplit; it is kept here because train()
# and test() read the *_pos_edge_index / *_neg_edge_index keys it produces.
data = train_test_split_edges(data)
# `Data.keys` is a property in older PyG releases and a method in newer ones;
# iterating the bound method directly would raise TypeError, so support both.
keys = data.keys() if callable(data.keys) else data.keys
for key in keys:
    print(key, getattr(data, key).shape)
# x torch.Size([2708, 1433])
# val_pos_edge_index torch.Size([2, 263])
# test_pos_edge_index torch.Size([2, 527])
# train_pos_edge_index torch.Size([2, 8976])
# train_neg_adj_mask torch.Size([2708, 2708])
# val_neg_edge_index torch.Size([2, 263])
# test_neg_edge_index torch.Size([2, 527])
# The counts only add up if train_pos_edge_index stores each undirected edge
# in both directions while the val/test positives store each edge once:
# 263 + 527 + 8976 = 9766 != 10556
# 263 + 527 + 8976/2 = 5278 = 10556/2
边预测任务
## 构造神经网络
import torch
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
    """Two-layer GCN node encoder with dot-product link decoders."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        """Embed every node: GCNConv -> ReLU -> GCNConv."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of their endpoint embeddings.

        The returned 1-D tensor lists the positive edges first, then the
        negative edges, in column order of the two index tensors.
        """
        candidates = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        src, dst = candidates[0], candidates[1]
        return torch.sum(z[src] * z[dst], dim=-1)

    def decode_all(self, z):
        """Return a (2, K) index tensor of every node pair with a positive score.

        The score matrix z @ z.T is symmetric, so both (i, j) and (j, i) are
        included, as is (i, i) whenever z[i] is non-zero.
        """
        scores = z.matmul(z.t())
        return torch.nonzero(scores > 0).t()
## 单epoch训练
def get_link_labels(pos_edge_index, neg_edge_index):
    """Build binary targets aligned with [positive edges, negative edges].

    Returns a 1-D float32 tensor holding ``pos_edge_index.size(1)`` ones
    followed by ``neg_edge_index.size(1)`` zeros. This matches the
    concatenation order used by ``Net.decode``.
    """
    n_pos = pos_edge_index.size(1)
    n_neg = neg_edge_index.size(1)
    return torch.cat([
        torch.ones(n_pos, dtype=torch.float),
        torch.zeros(n_neg, dtype=torch.float),
    ])
def train(data, model, optimizer):
    """Run one optimisation step of link prediction on the training edges.

    Negative edges (as many as there are positive training edges) are
    re-sampled on every call.

    Args:
        data: graph returned by ``train_test_split_edges``; this reads
            ``x``, ``num_nodes`` and ``train_pos_edge_index``.
        model: module exposing ``encode(x, edge_index)`` and
            ``decode(z, pos_edge_index, neg_edge_index)`` like ``Net``.
        optimizer: torch optimizer over the model's parameters.

    Returns:
        The scalar binary cross-entropy loss tensor for this step.
    """
    model.train()
    pos_edge_index = data.train_pos_edge_index
    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1))

    optimizer.zero_grad()
    z = model.encode(data.x, pos_edge_index)
    link_logits = model.decode(z, pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(pos_edge_index, neg_edge_index).to(data.x.device)
    # Fully qualified on purpose: no `import torch.nn.functional as F`
    # appears in this file, so `F.` would raise NameError here.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()
    return loss
## 单epoch验证与测试过程
@torch.no_grad()
def test(data, model):
    """Compute ROC-AUC on the validation and test edge splits.

    Node embeddings are computed from ``train_pos_edge_index`` only. Each
    split scores its positive and negative edges with ``model.decode``
    followed by a sigmoid.

    Returns:
        ``[val_auc, test_auc]``.
    """
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)

    def _split_auc(split):
        # ROC-AUC for one split, e.g. 'val' or 'test'.
        pos = data[f'{split}_pos_edge_index']
        neg = data[f'{split}_neg_edge_index']
        probs = model.decode(z, pos, neg).sigmoid()
        labels = get_link_labels(pos, neg)
        # NOTE(review): roc_auc_score is not imported anywhere in this file;
        # presumably sklearn.metrics.roc_auc_score — confirm and add the import.
        return roc_auc_score(labels.cpu(), probs.cpu())

    return [_split_auc(split) for split in ('val', 'test')]
参考资料:
https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN