图神经网络——基本图论与PyG库
图神经网络——基于图神经网络的图表征学习方法
图表征学习要求在输入节点属性、边和边的属性(如果有的话)得到一个向量作为图的表征,基于图表征进一步的我们可以做图的预测,而图同构网络(Graph Isomorphism Network, GIN)的图表征网络是当前最经典的图表征学习网络。
1.GNN的邻域聚合(消息传递)
GNN的目标是以图结构数据和节点特征作为输入,以学习到节点(或图)的embedding,用于分类任务。
基于邻域聚合的GNN可以拆分为以下三个模块:
- Aggregate:聚合一阶邻域特征。
- Combine:将邻居聚合的特征 与 当前节点特征合并, 以更新当前节点特征。
- Readout(可选):如果是对graph分类,需要将graph中所有节点特征转变成graph特征。
但是Aggregate的三种方式sum、mean、max的表征能力不够强大。

如上图,节点v和v’为中心节点,通过聚合邻居特征生成embedding来分析不同aggregate设置下是否能区分不同的结构。设红绿蓝色节点特征值分别为r,g,b,不考虑combine。
图a中:
mean:左$\frac{1}{2}(b+b)=b$,右$\frac{1}{3}(b+b+b)=b$,无法区分;
max:左$b$,右$b$,无法区分;
sum:左$2b$,右$3b$,可以区分。
图b中:
mean:左$\frac{1}{2}(r+g)$,右$\frac{1}{3}(g+2r)=b$,可以区分;
max:左$max(r,g)$,右$max(r,r,g)$,无法区分;
sum:左$r+g$,右$2r+g$,可以区分。
图c中:
mean:左$\frac{1}{2}(r+g)$,右$\frac{1}{4}(2g+2r)=b$,无法区分;
max:左$max(r,g)$,右$max(r,r,g,g)$,无法区分;
sum:左$r+g$,右$2r+2g$,可以区分。
这说明,sum基本可以学习精确的结构信息、mean偏向学习分布信息,max偏向学习有代表性的元素信息,无法区分某些结构的图,故性能会比sum差一点。
2.Weisfeiler-Lehman Test (WL Test)
图的同构性测试算法(Weisfeiler-Lehman),简称WL Test,是一种用于测试两个图是否同构的算法。
WL Test 的一维形式,类似GNN中的邻接节点聚合。WL Test先迭代地聚合节点及其邻接节点的标签,然后将聚合的标签散列(hash)成新标签,该过程形式化为下方的公示,
$$
L^{h}{u} \leftarrow \operatorname{hash}\left(L^{h-1}{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}{v}\right)
$$
上式中,$L^{h}{u}$表示节点$u$的第$h$次迭代的标签,第$0$次迭代的标签为节点原始标签。
在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。
WL测试不能保证对所有图都有效,特别是对于具有高度对称性的图,如链式图、完全图、环图和星图,它会判断错误。
给定两个图$G$和$G^{\prime}$,每个节点拥有标签(实际中,一些图没有节点标签,我们可以以节点的度作为标签)。

Weisfeiler-Leman Test 算法通过重复执行以下给节点打标签的过程来实现图是否同构的判断:
- 聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间用
,分隔,邻接节点的标签按升序排序。排序的原因在于要保证单射性,即保证输出的结果不因邻接节点的顺序改变而改变。

- 标签散列,即标签压缩,将较长的字符串映射到一个简短的标签。

- 给节点重新打上标签。

每重复一次以上的过程,就完成一次节点自身标签与邻接节点标签的聚合。
当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。如果上述的步骤重复一定的次数后,没有发现有相同节点标签的出现次数不一致的情况,那么我们无法判断两个图是否同构。
当两个节点的$h$层的标签一样时,表示分别以这两个节点为根节点的WL子树是一致的。WL子树与普通子树不同,WL子树包含重复的节点。下图展示了一棵以1节点为根节点高为2的WL子树。

3.图相似性评估
WL Test只能判断两个图的相似性,无法衡量图之间的相似性。要衡量两个图的相似性,需要用WL Subtree Kernel方法。该方法的思想是用WL Test算法得到节点的多层的标签,分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征**。**两个图的表征向量的内积,即可作为这两个图的相似性估计,内积越大表示相似性越高。

4.图同构网络模型的构建
通过GIN学习的节点表征向量可以用于类似于节点分类、边预测这样的任务。而对于图分类任务。READOUT函数:给定独立的节点的表征向量集,生成整个图的表征向量。
GIN的READOUT模块使用concat+sum,对每次迭代得到的所有节点特征求和得到图的特征,然后拼接起来,公式如下:
$$
h_{G} = \text{CONCAT}(\text{READOUT}\left({h_{v}^{(k)}|v\in G}\right)|k=0,1,\cdots, K)
$$
5.基于图同构网络(GIN)的图表征网络的实现
基于图同构网络的图表征学习主要包含以下两个过程:
- 计算得到节点表征;
- 对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout),得到图的表征(Graph Representation)。
基于图同构网络的图表征模块(GINGraphRepr Module),首先采用GINNodeEmbedding模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测。代码实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch from torch import nn from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set from gin_node import GINNodeEmbedding class GINGraphRepr(nn.Module): def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"): """ num_tasks(int, optional):图表征的维度; num_layers(int, optional):卷积层数; emb_dim(int, optional):node embedding的维度; residual(bool, optional):是否添加剩余的连接; drop_ratio (float, optional):dropout rate; JK (str, optional):可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。 raph_pooling (str, optional):node embedding的池化方法,可选的值为"sum","mean","max","attention"和"set2set"。 """ super(GINGraphPooling, self).__init__() self.num_layers = num_layers self.drop_ratio = drop_ratio self.JK = JK self.emb_dim = emb_dim self.num_tasks = num_tasks if self.num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual) # Pooling function to generate whole-graph embeddings if graph_pooling == "sum": self.pool = global_add_pool elif graph_pooling == "mean": self.pool = global_mean_pool elif graph_pooling == "max": self.pool = global_max_pool elif graph_pooling == "attention": self.pool = GlobalAttention(gate_nn=nn.Sequential( nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1))) elif graph_pooling == "set2set": self.pool = Set2Set(emb_dim, processing_steps=2) else: raise ValueError("Invalid graph pooling type.") if graph_pooling == "set2set": self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks) else: self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks) def forward(self, batched_data): h_node = self.gnn_node(batched_data) h_graph = self.pool(h_node, batched_data.batch) output = self.graph_pred_linear(h_graph) if self.training: return output else: # At inference time, relu is applied to output to ensure positivity # 因为预测目标的取值范围就在 (0, 50] 内 return torch.clamp(output, min=0, max=50) |
attention:基于Attention对节点表征加权求和,使用模块 torch_geometric.nn.glob.GlobalAttention。
set2set:另一种基于Attention对节点表征加权求和的方法,使用模块 torch_geometric.nn.glob.Set2Set。
6.基于图同构网络的节点嵌入模块(GINNodeEmbedding Module)
此节点嵌入模块基于多层GINConv实现结点嵌入的计算。首先用AtomEncoder对其做嵌入得到第0层节点表征然后我们逐层计算节点表征,从第1层开始到第num_layers层,每一层节点表征的计算都以上一层的节点表征h_list[layer]、边edge_index和边的属性edge_attr为输入。
需要注意的是,GINConv的层数越多,此节点嵌入模块的感受野(receptive field)越大,结点i的表征最远能捕获到结点i的距离为num_layers的邻接节点的信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch from mol_encoder import AtomEncoder from gin_conv import GINConv import torch.nn.functional as F # GNN to generate node embedding class GINNodeEmbedding(torch.nn.Module): """ Output: node representations """ def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False): super(GINNodeEmbedding, self).__init__() self.num_layers = num_layers self.drop_ratio = drop_ratio self.JK = JK self.residual = residual if self.num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.atom_encoder = AtomEncoder(emb_dim) # List of GNNs self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() for layer in range(num_layers): self.convs.append(GINConv(emb_dim)) self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) def forward(self, batched_data): x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr # computing input node embedding h_list = [self.atom_encoder(x)] # 先将类别型原子属性转化为原子表征 for layer in range(self.num_layers): h = self.convs[layer](h_list[layer], edge_index, edge_attr) h = self.batch_norms[layer](h) if layer == self.num_layers - 1: # remove relu for the last layer h = F.dropout(h, self.drop_ratio, training=self.training) else: h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) if self.residual: h += h_list[layer] h_list.append(h) # Different implementations of Jk-concat if self.JK == "last": node_representation = h_list[-1] elif self.JK == "sum": node_representation = 0 for layer in range(self.num_layers + 1): node_representation += h_list[layer] return node_representation |
7.图同构卷积层(GINConv)
图同构卷积层GINConv的数学定义如下:
$$
\mathbf{x}^{\prime}i = h{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot\mathbf{x}i + \sum{j \in \mathcal{N}(i)} \mathbf{x}_j \right)
$$
构建代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch from torch import nn from torch_geometric.nn import MessagePassing import torch.nn.functional as F from ogb.graphproppred.mol_encoder import BondEncoder # GIN convolution along the graph structure class GINConv(MessagePassing): def __init__(self, emb_dim): #emb_dim (int): node embedding dimensionality super(GINConv, self).__init__(aggr = "add") self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim)) self.eps = nn.Parameter(torch.Tensor([0])) self.bond_encoder = BondEncoder(emb_dim = emb_dim) def forward(self, x, edge_index, edge_attr): edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征 out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) return out def message(self, x_j, edge_attr): return F.relu(x_j + edge_attr) def update(self, aggr_out): return aggr_out |