图神经网络——基本图论与PyG库

图神经网络——超大规模数据集类的创建和图预测任务实践

当数据集规模超级大时,很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。

1 Dataset基类

1.1 Dataset基类介绍

在PyG中,通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。
继承此基类相比较继承torch_geometric.data.InMemoryDataset基类要多实现以下方法:

  • len():返回数据集中的样本的数量。
  • get():实现加载单个图的操作。注意:在内部,getitem()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。
1.2 继承torch_geometric.data.Dataset基类的代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url

class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)

@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]

@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]

def download(self):
# Download to `self.raw_dir`.
path = download_url(url, self.raw_dir)
...

def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)

if self.pre_filter is not None and not self.pre_filter(data):
continue

if self.pre_transform is not None:
data = self.pre_transform(data)

torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1

def len(self):
return len(self.processed_file_names)

def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
1.3 特殊情况
  • download/process步骤可以跳过
    对于无需下载数据集原文件的情况,不重写(override)download方法即可跳过下载。
    对于无需对数据集做预处理的情况,不重写process方法即可跳过预处理。
  • 有些Dataset类无需定义
    可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:
    1
    2
    3
    4
    from torch_geometric.data import Data, DataLoader

    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
    或者将一个列表的Data对象组成一个batch
    1
    2
    3
    4
    from torch_geometric.data import Data, Batch

    data_list = [Data(...), ..., Data(...)]
    loader = Batch.from_data_list(data_list, batch_size=32)

2 图样本封装成批(BATCHING)与DataLoader

2.1 合并小图组成大图

PyTorch Geometric中采用的是将多个图封装成批的方式,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。
该方法的优势在于:

  • 依靠消息传递方案的GNN运算不需要被修改。
  • 没有额外的计算或内存的开销。(因为它们是以稀疏的方式保存的,只保留非零项)

通过torch_geometric.data.DataLoader类,多个小图被封装成一个大图。

2.2 小图的属性增值与拼接

将小图存储到大图中时需要对小图的属性做一些修改,比如对节点序号增值,PyTorch Geometric的DataLoader类会自动对edge_index张量增值,增加的值为当前被处理图的前面的图的累积节点数量。

2.3 图的匹配(Pairs of Graphs)

不同类型的节点数量不一致,edge_index边的源节点与目标节点进行增值操作不同。

2.4 二部图(Bipartite Graphs)

部图是图论中的一种特殊模型。设$G=(V,E)$是一个无向图,如果顶点$V$可分割为两个互不相交的子集$(A,B)$,并且图中的每条边$(i,j)$所关联的两个顶点i和j分别属于这两个不同的顶点集$(i in A,j in B)$,则称图$G$为一个二部图。它的邻接矩阵定义两种类型的节点之间的连接关系。一般来说,不同类型的节点数量不需要一致,于是二部图的邻接矩阵$A \in {0,1}^{N \times M}$可能为平方矩阵,即可能有$N \neq M$。

2.5 在新的维度上做拼接

图级别属性或预测目标,Data对象的属性需要在一个新的维度上做拼接,此时形状为[num_features]的属性列表应该被返回为[num_examples, num_features],而不是[num_examples * num_features]

3 创建超大规模数据集类实践

PCQM4M-LSC是一个分子图的量子特性回归数据集,它包含了3,803,453个图。
定义的数据集类如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil

RDLogger.DisableLog('rdApp.*')

class MyPCQM4MDataset(Dataset):

def __init__(self, root):
self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
super(MyPCQM4MDataset, self).__init__(root)

filepath = osp.join(root, 'raw/data.csv.gz')
data_df = pd.read_csv(filepath)
self.smiles_list = data_df['smiles']
self.homolumogap_list = data_df['homolumogap']

@property
def raw_file_names(self):
return 'data.csv.gz'

def download(self):
path = download_url(self.url, self.root)
extract_zip(path, self.root)
os.unlink(path)
shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

def len(self):
return len(self.smiles_list)

def get(self, idx):
smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
graph = smiles2graph(smiles)
assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
assert(len(graph['node_feat']) == graph['num_nodes'])

x = torch.from_numpy(graph['node_feat']).to(torch.int64)
edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
y = torch.Tensor([homolumogap])
num_nodes = int(graph['num_nodes'])
data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
return data

# 获取数据集划分
def get_idx_split(self):
split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
return split_dict

if __name__ == "__main__":
dataset = MyPCQM4MDataset('dataset2')
from torch_geometric.data import DataLoader
from tqdm import tqdm
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
for batch in tqdm(dataloader):
pass

在生成一个该数据集类的对象时:

  • 首先会检查指定的文件夹下是否存在data.csv.gz文件,如果不在,则会执行download方法,这一过程是在运行super类的__init__方法中发生的。
  • 然后程序继续执行__init__方法的剩余部分,读取data.csv.gz文件,获取存储图信息的smiles格式的字符串,以及回归预测的目标homolumogap。由smiles格式的字符串转成图的过程在get()方法中实现,这样在生成一个ataLoader变量时,通过指定num_workers可以实现并行执行生成多个图。

4 图预测任务实践

4.1 通过试验寻找最佳超参数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#!/bin/sh

python main.py --task_name GINGraphPooling\ # 为当前试验取名
--device 0\
--num_layers 5\ # 使用GINConv层数
--graph_pooling sum\ # 图读出方法
--emb_dim 256\ # 节点嵌入维度
--drop_ratio 0.\
--save_test\ # 是否对测试集做预测并保留预测结果
--batch_size 512\
--epochs 100\
--weight_decay 0.00001\
--early_stop 10\ # 当有early_stop个epoches验证集结果没有提升,则停止训练
--num_workers 4\
--dataset_root dataset # 存放数据集的根目录

这段代码运行后,程序会在saves目录下创建一个task_name参数指定名称的文件夹用于记录试验过程。试验运行过程中,所有的print输出都会写入到试验文件夹下的output文件,tensorboard.SummaryWriter记录的信息也存储在试验文件夹下的文件中。
修改上方的命令再执行,即可试验不同的超参数,所有试验的过程与结果信息都存储于saves文件夹下。启动TensorBoard会话,选择saves文件夹,即可查看所有试验的过程与结果信息。

参考资料

1.datawhale-GNN开源学习资料
2.Dataset类官方文档

0%