图神经网络——超大规模数据集类的创建和图预测任务实践
当数据集规模超级大时,很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。
1 Dataset
基类
1.1 Dataset
基类介绍
在PyG中,通过继承torch_geometric.data.Dataset
基类来自定义一个按需加载样本到内存的数据集类。
继承此基类相比较继承torch_geometric.data.InMemoryDataset
基类要多实现以下方法:
len()
:返回数据集中的样本的数量。get()
:实现加载单个图的操作。注意:在内部,getitem()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。
1.2 继承torch_geometric.data.Dataset基类的代码实现
1 | import os.path as osp |
1.3 特殊情况
- download/process步骤可以跳过
对于无需下载数据集原文件的情况,不重写(override)download
方法即可跳过下载。
对于无需对数据集做预处理的情况,不重写process
方法即可跳过预处理。 - 有些Dataset类无需定义
可以不用定义一个Dataset
类,而直接生成一个Dataloader
对象,直接用于训练:或者将一个列表的1
2
3
4from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)Data
对象组成一个batch
:1
2
3
4from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)
2 图样本封装成批(BATCHING)与DataLoader
类
2.1 合并小图组成大图
PyTorch Geometric中采用的是将多个图封装成批的方式,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。
该方法的优势在于:
- 依靠消息传递方案的GNN运算不需要被修改。
- 没有额外的计算或内存的开销。(因为它们是以稀疏的方式保存的,只保留非零项)
通过torch_geometric.data.DataLoader
类,多个小图被封装成一个大图。
2.2 小图的属性增值与拼接
将小图存储到大图中时需要对小图的属性做一些修改,比如对节点序号增值,PyTorch Geometric的DataLoader
类会自动对edge_index
张量增值,增加的值为当前被处理图的前面的图的累积节点数量。
2.3 图的匹配(Pairs of Graphs)
不同类型的节点数量不一致,edge_index
边的源节点与目标节点进行增值操作不同。
2.4 二部图(Bipartite Graphs)
部图是图论中的一种特殊模型。设$G=(V,E)$是一个无向图,如果顶点$V$可分割为两个互不相交的子集$(A,B)$,并且图中的每条边$(i,j)$所关联的两个顶点i和j分别属于这两个不同的顶点集$(i in A,j in B)$,则称图$G$为一个二部图。它的邻接矩阵定义两种类型的节点之间的连接关系。一般来说,不同类型的节点数量不需要一致,于是二部图的邻接矩阵$A \in {0,1}^{N \times M}$可能为平方矩阵,即可能有$N \neq M$。
2.5 在新的维度上做拼接
图级别属性或预测目标,Data
对象的属性需要在一个新的维度上做拼接,此时形状为[num_features]
的属性列表应该被返回为[num_examples, num_features]
,而不是[num_examples * num_features]
。
3 创建超大规模数据集类实践
PCQM4M-LSC是一个分子图的量子特性回归数据集,它包含了3,803,453个图。
定义的数据集类如下:
1 | import os |
在生成一个该数据集类的对象时:
- 首先会检查指定的文件夹下是否存在
data.csv.gz
文件,如果不在,则会执行download
方法,这一过程是在运行super
类的__init__
方法中发生的。 - 然后程序继续执行
__init__
方法的剩余部分,读取data.csv.gz
文件,获取存储图信息的smiles
格式的字符串,以及回归预测的目标homolumogap
。由smiles
格式的字符串转成图的过程在get()
方法中实现,这样在生成一个ataLoader
变量时,通过指定num_workers
可以实现并行执行生成多个图。
4 图预测任务实践
4.1 通过试验寻找最佳超参数
1 | !/bin/sh |
这段代码运行后,程序会在saves
目录下创建一个task_name
参数指定名称的文件夹用于记录试验过程。试验运行过程中,所有的print
输出都会写入到试验文件夹下的output
文件,tensorboard.SummaryWriter
记录的信息也存储在试验文件夹下的文件中。
修改上方的命令再执行,即可试验不同的超参数,所有试验的过程与结果信息都存储于saves
文件夹下。启动TensorBoard
会话,选择saves
文件夹,即可查看所有试验的过程与结果信息。