消息传递(Message Passing) 指的是目标节点$S1$的邻居$\mathcal{N(S1)}$——B1、B2、B3,这些邻居节点根据一定的规则将信息(特征),汇总到目标节点上。信息汇总中最简单的规则就是逐个元素相加。
在pytorch-geometric的官方文档中,消息传递图神经网络被描述为:
$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}i^{(k-1)}, \square{j \in \mathcal{N}(i)} , \phi^{(k)}\left(\mathbf{x}i^{(k-1)}, \mathbf{x}j^{(k-1)},\mathbf{e}{j,i}\right) \right),
$$
其中,$\mathbf{e}{j,i} \in \mathbb{R}^D$ 表示从节点$j$到节点$i$的边的属性,$\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F$表示$(k-1)$层中节点$i$的节点表征,$\square$表示聚合策略,$\gamma$和$\phi$表示一些神经网络方法,比如MLPs多层感知器、LSTM等。
从公式中可以看出,目标节点$x_i$在k层的特征可以通过$x_i$在上一层(k-1层)的特征与其相邻节点$x_j$在上一层(k-1层)的特征以及相邻节点到目标节点的边的特征,这三个特征在k层通过$\square$的聚合策略(aggregate),通过一个$\gamma$在k层的分析方法来导出目标节点$x_i$的特征。
2. MessagePassing基类
Pytorch Geometric(PyG)提供了MessagePassing基类,通过继承基类,并定义message()方法、update()方法、aggregate()方法,可以构造消息传递图神经网络。
2.1 MessagePassing.__init__(aggr=”add”, flow=”source_to_target”, node_dim=-2)
aggr:定义聚合方案(“add”、”mean”或 “max”),默认值”add”;
flow:定义消息传递方向(“source_to_target”或 “target_to_source”),默认值”source_to_target”;
node_dim:定义沿哪个维度传播,指的是节点表征张量(Tensor)的哪一个维度是节点维度,默认值-2(第0维)。
2.2 MessagePassing.propagate(edge_index, size=None, **kwargs)
这是一个集成方法,调用其会依次调用message、aggregate、update方法。
edge_index:边的端点的索引,当flow=”source_to_target”时,节点edge_index[0]的信息将被传递到节点edge_index[1];当flow=”target_to_source”时,节点edge_index[1]的信息将被传递到节点edge_index[0]。
size:邻接节点的数量与中心节点的数量,默认值None(对称矩阵);
**kwaegs:图的其他属性或额外的数据。
2.3 MessagePassing.message(…)
以函数的方式构造消息;
flow=”source_to_target”,此方式下,message方法负责产生source node需要传出的信息。
2.4 MessagePassing.update(aggr_out, …)
为每个节点$i \in \mathcal{V}$更新节点表征。
2.5 MessagePassing.aggregate(…)
将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum, mean和max。
3. MessagePassing子类
3.1 GCNConv类
以继承MessagePassing基类的GCNConv类为例,可以实现一个简单的GNN。
GCNConv的公式如下:
$$
\mathbf{x}i^{(k)} = \sum{j \in \mathcal{N}(i) \cup { i }} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)
$$
其中,邻接节点的表征$\mathbf{x}_j^{(k-1)}$首先通过与权重矩阵$\mathbf{\Theta}$相乘进行变换,然后按端点的度(degree)$\deg(i), \deg(j)$进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:
向邻接矩阵添加自环边。
对节点表征做线性转换。
计算归一化系数。
归一化邻接节点的节点表征。
将相邻节点表征相加(”求和 “聚合)。
GCNConv继承了MessagePassing,并以”求和”作为领域节点信息聚合方式。该层的所有逻辑都在forward()方法中:
1.通过torch_geometric.utils.add_self_loops()函数向边索引添加自循环边,目的是改进原始不考虑中心节点自身的信息量的问题;
2.通过torch.nn.Linear实例对节点表征进行线性变换;
3.归一化系数是由每个节点的节点度得出的,它被转换为每条边的节点度。结果被保存在形状为[num_edges,]的变量norm中。
3.2 实现
1 | from torch_geometric.datasets import Planetoid |