3 min read

图异常检测:GraphSAGE 归纳式方法及其与GCN的对比

引言

图神经网络(Graph Neural Networks, GNNs)近年来在处理图结构数据方面取得了显著进展。在众多GNN模型中,图卷积网络(Graph Convolutional Networks, GCN)和GraphSAGE(Graph SAmple and aggreGatE)是两种最具代表性的方法。本文将深入探讨GraphSAGE的工作原理,并与GCN进行全面对比,帮助读者理解这两种方法的异同点及适用场景。

图神经网络的基本概念

在深入讨论GraphSAGE和GCN之前,我们先简要回顾图神经网络的基本概念。

图数据由节点(Nodes)和边(Edges)组成,可以表示为G = (V, E),其中V是节点集合,E是边集合。图神经网络的核心任务是学习节点的表示(Node Embeddings),即将每个节点映射到低维向量空间,使得在原图中相似的节点在向量空间中也相近。

图神经网络通常采用消息传递(Message Passing)机制,通过聚合邻居节点的信息来更新中心节点的表示。这一过程可以迭代多次,使得每个节点能够捕获更广泛的结构信息。

GCN:基于谱图理论的图卷积

GCN的基本原理

GCN由Kipf和Welling在2016年提出,是一种基于谱图理论的图卷积网络。GCN的核心思想是将传统卷积神经网络(CNN)中的卷积操作推广到图结构数据上。

GCN的层更新公式为:

H(l+1)=σ(D~12A~D~12H(l)W(l))

其中: - A~=A+I 是添加了自环的邻接矩阵 - D~A~ 的度矩阵 - H(l) 是第 l 层的节点特征 - W(l) 是可学习的权重矩阵 - σ 是非线性激活函数

GCN的优缺点

优点: 1. 理论基础扎实,基于谱图理论 2. 模型简洁,易于实现 3. 在小型图上表现良好

缺点: 1. 转导式学习(Transductive Learning),难以处理未见过的节点 2. 需要整个图的邻接矩阵,难以处理大规模图 3. 固定的聚合方式,缺乏灵活性 4. 存在过平滑(Over-smoothing)问题,层数增加时性能下降

GraphSAGE:归纳式图表示学习

GraphSAGE的基本原理

GraphSAGE由Hamilton等人在2017年提出,是一种归纳式图表示学习方法。与GCN不同,GraphSAGE通过采样和聚合邻居信息来生成节点嵌入,使得模型可以处理大规模图和动态变化的图结构。

GraphSAGE的关键创新在于: 1. 通过采样固定数量的邻居,避免处理整个图 2. 将节点自身信息与邻居信息进行聚合 3. 支持多种可学习的聚合函数

GraphSAGE的更新过程包含三个步骤:

  1. 邻居采样:对每个节点,随机采样固定数量的邻居

  2. 信息聚合:使用聚合函数整合邻居信息 hN(v)k=AGGREGATEk({huk1,uN(v)})

  3. 节点更新:将节点自身表示hvk1与聚合后的邻居信息hN(v)k进行拼接,然后通过全连接层和非线性变换: hvk=σ(WkCONCAT(hvk1,hN(v)k)) 最后进行归一化: hvk=hvk/hvk2

其中,N(v) 是节点 v 的采样邻居集合,AGGREGATE 是聚合函数。这种设计使得GraphSAGE能够同时利用节点自身特征和局部结构信息。

GraphSAGE的聚合函数

GraphSAGE提供了多种聚合函数选择:

  1. Mean Aggregator:计算邻居特征的平均值 AGGREGATEmean=MEAN({hu,uN(v)})

  2. Max-pooling Aggregator:对邻居特征进行非线性变换后取最大值 AGGREGATEmax=MAX({σ(Wpoolhu+b),uN(v)})

  3. LSTM Aggregator:使用LSTM处理邻居特征序列 AGGREGATELSTM=LSTM({hu,uN(v)})

GraphSAGE的优缺点

优点: 1. 归纳式学习(Inductive Learning),可以处理未见过的节点 2. 通过邻居采样策略,可以处理大规模图 3. 灵活的聚合函数,适应不同类型的图结构 4. 可以通过小批量训练减少内存需求

缺点: 1. 采样策略可能导致信息损失 2. 聚合函数的选择需要经验 3. 多层堆叠时仍存在过平滑问题

GraphSAGE与GCN的对比分析

计算范式

GCN:采用转导式学习,需要在训练时看到整个图,难以处理新节点。

GraphSAGE:采用归纳式学习,通过学习聚合函数,可以为未见过的节点生成嵌入。

可扩展性

GCN:需要整个图的邻接矩阵,空间复杂度为O(|V|²),难以处理大规模图。

GraphSAGE:通过邻居采样,每次只处理固定数量的邻居,空间复杂度与采样数量相关,可以处理大规模图。

聚合方式

GCN:使用固定的归一化拉普拉斯矩阵进行聚合,缺乏灵活性。

GraphSAGE:提供多种聚合函数(均值、最大值、LSTM等),可以根据图的特性选择合适的聚合方式。

内存需求

GCN:需要将整个图加载到内存中。

GraphSAGE:可以通过小批量训练和邻居采样减少内存需求。

适用场景

GCN:适合小型静态图,如引文网络、社交网络的子图等。

GraphSAGE:适合大规模图、动态变化的图,以及需要处理新节点的场景,如大型社交网络、推荐系统等。

实验对比

在实际应用中,GraphSAGE和GCN表现出以下主要差异:

  1. 可扩展性
    • GCN需要处理整个图的邻接矩阵,难以扩展到大规模图
    • GraphSAGE通过邻居采样,可以高效处理包含数百万节点的图
  2. 内存需求
    • GCN需要将整个图加载到内存中
    • GraphSAGE只需要维护采样得到的子图,内存需求显著降低
  3. 训练效率
    • 在小型图上,两者训练速度相当
    • 在大型图上,GraphSAGE通过采样策略大幅减少计算量
  4. 归纳能力
    • GCN难以处理未见过的节点
    • GraphSAGE可以为新节点生成嵌入,适合动态变化的图

这些差异使得GraphSAGE在工业级大规模图数据上更具优势,而GCN在小型静态图上仍然是一个简单有效的选择。

GraphSAGE的实际应用

GraphSAGE在多个领域有广泛应用:

  1. 社交网络分析:用户推荐、社区发现、影响力分析
  2. 推荐系统:物品推荐、用户兴趣建模
  3. 生物信息学:蛋白质相互作用网络、药物发现
  4. 知识图谱:实体关系预测、知识补全
  5. 交通网络:交通流量预测、路径规划

代码实现

GraphSAGE实现

下面是使用PyTorch Geometric实现GraphSAGE的核心代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, 
                 dropout=0.5, aggr='mean'):
        super(GraphSAGE, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        
        # 输入层
        self.convs.append(SAGEConv(in_channels, hidden_channels, aggr=aggr))
        
        # 隐藏层
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels, aggr=aggr))
            
        # 输出层
        self.convs.append(SAGEConv(hidden_channels, out_channels, aggr=aggr))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            
            if i < self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        
        return x

GCN实现

下面是使用PyTorch Geometric实现GCN的核心代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, 
                 dropout=0.5):
        super(GCN, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        
        # 输入层
        self.convs.append(GCNConv(in_channels, hidden_channels))
        
        # 隐藏层
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            
        # 输出层
        self.convs.append(GCNConv(hidden_channels, out_channels))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            
            if i < self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        
        return x

两种模型的训练代码

下面是训练GraphSAGE和GCN模型的代码示例:

def train(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    
    # 前向传播
    out = model(data.x, data.edge_index)
    
    # 计算损失
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    
    # 反向传播
    loss.backward()
    optimizer.step()
    
    return loss.item()

def test(model, data):
    model.eval()
    
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        
        # 计算准确率
        train_acc = pred[data.train_mask].eq(data.y[data.train_mask]).sum().item() / data.train_mask.sum().item()
        val_acc = pred[data.val_mask].eq(data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
        test_acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    
    return train_acc, val_acc, test_acc

结论

GraphSAGE作为一种归纳式图表示学习方法,在处理大规模图和动态图方面具有显著优势。与GCN相比,GraphSAGE通过邻居采样和灵活的聚合函数,实现了更好的可扩展性和适应性。

参考文献

  1. Hamilton, W. L., Ying, R., & Leskovec, J. (2017). Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems (pp. 1024-1034).

  2. Kipf, T. N., & Welling, M. (2016). Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.