图神经网络——消息传递网络

消息传递(Message Passing) 指的是目标节点$S1$的邻居$\mathcal{N(S1)}$——B1、B2、B3,这些邻居节点根据一定的规则将信息(特征),汇总到目标节点上。信息汇总中最简单的规则就是逐个元素相加。
在pytorch-geometric的官方文档中,消息传递图神经网络被描述为:
$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}i^{(k-1)}, \square{j \in \mathcal{N}(i)} , \phi^{(k)}\left(\mathbf{x}i^{(k-1)}, \mathbf{x}j^{(k-1)},\mathbf{e}{j,i}\right) \right),
$$
其中,$\mathbf{e}
{j,i} \in \mathbb{R}^D$ 表示从节点$j$到节点$i$的边的属性,$\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F$表示$(k-1)$层中节点$i$的节点表征,$\square$表示聚合策略,$\gamma$和$\phi$表示一些神经网络方法,比如MLPs多层感知器、LSTM等。
从公式中可以看出,目标节点$x_i$在k层的特征可以通过$x_i$在上一层(k-1层)的特征与其相邻节点$x_j$在上一层(k-1层)的特征以及相邻节点到目标节点的边的特征,这三个特征在k层通过$\square$的聚合策略(aggregate),通过一个$\gamma$在k层的分析方法来导出目标节点$x_i$的特征。

2. MessagePassing基类

Pytorch Geometric(PyG)提供了MessagePassing基类,通过继承基类,并定义message()方法、update()方法、aggregate()方法,可以构造消息传递图神经网络。

2.1 MessagePassing.__init__(aggr=”add”, flow=”source_to_target”, node_dim=-2)

aggr:定义聚合方案(“add”、”mean”或 “max”),默认值”add”;
flow:定义消息传递方向(“source_to_target”或 “target_to_source”),默认值”source_to_target”;
node_dim:定义沿哪个维度传播,指的是节点表征张量(Tensor)的哪一个维度是节点维度,默认值-2(第0维)。

2.2 MessagePassing.propagate(edge_index, size=None, **kwargs)

这是一个集成方法,调用其会依次调用message、aggregate、update方法。
edge_index:边的端点的索引,当flow=”source_to_target”时,节点edge_index[0]的信息将被传递到节点edge_index[1];当flow=”target_to_source”时,节点edge_index[1]的信息将被传递到节点edge_index[0]。
size:邻接节点的数量与中心节点的数量,默认值None(对称矩阵);
**kwaegs:图的其他属性或额外的数据。

2.3 MessagePassing.message(…)

以函数的方式构造消息;
flow=”source_to_target”,此方式下,message方法负责产生source node需要传出的信息。

2.4 MessagePassing.update(aggr_out, …)

为每个节点$i \in \mathcal{V}$更新节点表征。

2.5 MessagePassing.aggregate(…)

将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum, mean和max。

3. MessagePassing子类

3.1 GCNConv类

以继承MessagePassing基类的GCNConv类为例,可以实现一个简单的GNN。
GCNConv的公式如下:
$$
\mathbf{x}i^{(k)} = \sum{j \in \mathcal{N}(i) \cup { i }} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)
$$
其中,邻接节点的表征$\mathbf{x}_j^{(k-1)}$首先通过与权重矩阵$\mathbf{\Theta}$相乘进行变换,然后按端点的度(degree)$\deg(i), \deg(j)$进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:

向邻接矩阵添加自环边。
对节点表征做线性转换。
计算归一化系数。
归一化邻接节点的节点表征。
将相邻节点表征相加(”求和 “聚合)。

GCNConv继承了MessagePassing,并以”求和”作为领域节点信息聚合方式。该层的所有逻辑都在forward()方法中:
1.通过torch_geometric.utils.add_self_loops()函数向边索引添加自循环边,目的是改进原始不考虑中心节点自身的信息量的问题;
2.通过torch.nn.Linear实例对节点表征进行线性变换;
3.归一化系数是由每个节点的节点度得出的,它被转换为每条边的节点度。结果被保存在形状为[num_edges,]的变量norm中。

3.2 实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch_geometric.datasets import Planetoid
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
#in_channels可理解为输入通道数,out_channels可理解为卷积核数量
super(GCNConv, self).__init__(aggr='add', flow='source_to_target') #继承,策略为合并
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)

def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]

# Step 1: 添加自循环到节点特征矩阵(Add self-loops to the adjacency matrix.)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

# Step 2: 节点特征矩阵的线性变换(Linearly transform node feature matrix.)
x = self.lin(x)

# Step 3: Compute normalization.
row, col = edge_index #用行、列描述特征矩阵
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

# Step 4-5: Start propagating messages.
adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 此处传的不再是edge_index,而是SparseTensor类型的Adjancency Matrix
return self.propagate的(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))
# 此处省略MessagePassing.propagate的代码.

def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# deg_i has shape [E, 1]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j

def aggregate(self, inputs, index, ptr, dim_size):
print('self.aggr:', self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

def message_and_aggregate(self, adj_t, x, norm):
print('`message_and_aggregate` is called')
# 没有实现真实的消息传递与消息聚合的操作

def update(self, inputs, deg):
print(deg)
return inputs

dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
# print(h_nodes.shape)

参考资料

1.datawhale-GNN开源学习资料

欢迎关注我们的公众号
0%