「机器学习」图神经网络(GNN)的解决方法

「机器学习」图神经网络(GNN)的解决方法
2022年11月16日 11:16 潮流科技电商

图神经网络(GNN) 是在机器学习中利用图结构数据的强大工具。图是灵活的数据结构,可以对许多不同类型的关系进行建模,并已经用于各种应用,如交通预测、谣言和假新闻检测、疾病传播建模以及了解分子为何产生气味等。

图表可以对许多不同类型的数据之间的关系进行建模,包括网页(左)、社交关系(中)或分子(右)

作为机器学习 (ML) 的标准,GNN 假设训练样本是随机均匀选择的(即,是一个独立且同分布的或“IID”样本)。使用标准学术数据集很容易做到这一点,这些数据集专为研究分析而创建,因此每个节点都已标记。然而,在许多现实世界的场景中,数据没有标签,并且标记数据可能是一个繁重的过程,涉及熟练的人类评估者,这使得标记所有节点变得困难。此外,有偏差的训练数据是一个常见问题,因为选择节点进行标记的行为通常不是 IID。例如,有时使用固定启发式方法来选择数据子集(共享某些特征)进行标记,而其他时候,人类分析师使用复杂的领域知识单独选择数据项进行标记。

本地化训练数据是图结构数据中表现出的典型非 IID 偏差。左图显示了一个橙色节点并扩展到它周围的节点。相反,用于标记节点的 IID 训练样本均匀分布,如右侧的采样过程所示。

为了量化训练存在的偏差量,可以使用测量两个不同概率分布之间偏移量的方法,其中偏移量可以被认为是偏差量。随着这种转变规模的扩大,机器学习模型更难以从有偏的训练集中进行泛化。这种情况可能会严重损害泛化性——在学术数据集上,Google观察到域转移导致性能下降 15-20%(由F1 分数衡量)。

在NeurIPS 2021上发表的 “ Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data ”中,Google介绍了一种在有偏数据上使用 GNN 的解决方案。这种方法称为 Shift-Robust GNN (SR-GNN),旨在解决有偏差的训练数据与图的真实推理分布之间的分布差异。SR-GNN 使 GNN 模型适应标记为训练的节点与数据集的其余部分之间存在的分布偏移。Google在用于半监督学习的通用 GNN 基准数据集上使用有偏训练数据集进行的各种实验中说明了 SR-GNN 的有效性,并表明 SR-GNN 在准确性方面优于其他 GNN 基线,将有偏训练数据的负面影响减少了 30 –40%。

分布变化对性能的

影响为了展示分布变化如何影响 GNN 性能,Google首先为已知的学术数据集生成了一些有偏差的训练集。然后为了理解效果,Google绘制了泛化(测试准确度)与分布偏移量度(中心矩差异1,CMD)的关系图。例如,考虑著名的PubMed引文数据集,可以将其视为一个图,其中节点是医学研究论文,边表示它们之间的引文。当Google为 PubMed 生成有偏差的训练数据时,该图如下所示:

分布变化对 PubMed 数据集的影响。对于 100 个有偏差的训练集样本,性能 ( F1 ) 显示在 y 轴上与分布偏移、中心矩差异 ( CMD ) 在 x 轴上。随着分布偏移的增加,模型的准确性下降。

在这里可以观察到数据集中的分布变化与分类精度之间存在很强的负相关:随着 CMD 的增加,性能(F1)下降。也就是说,GNN 可能难以泛化,因为它们的训练数据看起来不像测试数据集。

为了解决这个问题,Google提出了一个 shift-robust 正则化器(在思想上类似于域不变学习),以最小化训练数据和来自未标记数据的 IID 样本之间的分布偏移。为此,Google在模型训练时实时测量域转移(例如,通过 CMD),并在此基础上应用直接惩罚,迫使模型尽可能多地忽略训练偏差。这迫使模型为训练数据学习的特征编码器也可以有效地处理任何未标记的数据,这些数据可能来自不同的分布。

下图显示了与传统 GNN 模型相比的情况。Google仍然有相同的输入(节点特征X和邻接矩阵A)和相同数量的层。然而,在最终嵌入时,来自 GNN 的层 ( k ) 的Z k与来自未标记数据点的嵌入进行比较,以验证模型是否正确编码它们。

SR-GNN 为深度 GNN 模型添加了两种正则化。首先,域移位正则化(λ项)最小化标记( Z k)和未标记(Z IID)数据的隐藏表示之间的距离。其次,可以更改示例的实例权重 ( β ) 以进一步逼近真实分布

Google将此正则化写为基于训练数据表示与真实数据分布之间的距离的模型损失公式中的附加项(论文中提供了完整的公式)。

在Google的实验中,Google比较了Google的方法和一些标准的图神经网络模型,以衡量它们在节点分类任务上的表现。Google证明,添加 SR-GNN 正则化可以使训练数据标签有偏差的分类任务提高 30-40%。

使用节点分类的 SR-GNN 与 PubMed 数据集上的有偏训练数据的比较。SR-GNN优于七个基线,包括DGI、GCN、GAT、SGC和APPNP。

通过实例重新加权对线性 GNN 进行 Shift-Robust 正则化

此外,值得注意的是,还有另一类 GNN 模型(例如,APPNP、SimpleGCN等)基于线性运算来加速其图卷积。Google还研究了如何在存在有偏差的训练数据的情况下使这些模型更可靠。虽然由于架构不同,无法直接应用相同的正则化机制,但Google可以通过根据训练实例与近似真实分布的距离重新加权训练实例来“纠正”训练偏差。这允许在不通过模型传递梯度的情况下纠正有偏差的训练数据的分布。

最后,对于深度和线性 GNN 的两种正则化可以组合成一个针对损失的广义正则化,它结合了域正则化和实例重新加权(详细信息,包括损失公式,可在论文中找到)。

有偏差的训练数据在现实世界场景中很常见,并且可能由于多种原因而出现,包括标记大量数据的困难、用于选择标记节点的各种启发式或不一致的技术、延迟的标签分配以及其他。Google提出了一个通用框架(SR-GNN),可以减少有偏差的训练数据的影响,并且可以应用于各种类型的 GNN,包括更深的 GNN 和这些模型的最新线性化(浅)版本。

财经自媒体联盟更多自媒体作者

新浪首页 语音播报 相关新闻 返回顶部