之前我的实现方式相对而言麻烦且准确率不够好,只能达到65%左右的准确率(Cora上),这里介绍直接用PyG封装好的GAT函数实现:
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GATConv
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.gat1=GATConv(dataset.num_node_features,