Skip to content

Graph Attention Networks

Graph attention network 用可学习、数据依赖的权重替代了统一的邻居聚合。本文件覆盖 GAT、多头 graph attention、GATv2、Graph Transformer、位置与结构编码以及可扩展性。

  • 在 GCN(文件 3)中,每个 node 都用由 graph 结构(normalised adjacency)决定的固定权重来聚合邻居 feature。若一个 node 有三个邻居,每个邻居大约分到 \(\approx 1/3\) 权重。但现实中邻居的重要性并不相同:来自紧密协作者的消息通常应比来自远关系节点的消息更重要。

  • Graph Attention Network 通过学习“该 attention 哪些邻居”来解决这个问题,机制与第 7 章 transformer 的 attention 相同。与固定的结构权重不同,每个 node 会对邻居动态计算基于内容的 attention score。

GAT:Graph Attention Network

  • GAT(Veličković et al., 2018)在 node 与其邻居之间计算 attention 系数。对 node \(i\) 及其邻居 \(j\)
\[e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)\]
  • 其中 \(W \in \mathbb{R}^{d' \times d}\) 是共享线性变换,\(\|\) 表示 concatenation,\(\mathbf{a} \in \mathbb{R}^{2d'}\) 是可学习的 attention 向量。分数 \(e_{ij}\) 衡量 node \(j\) 的 feature 对 node \(i\) 的重要性。

  • 原始分数会在所有邻居上通过 softmax 做归一化:

\[\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}\]
  • 这保证每个 node 的邻域 attention 权重和为 1,与第 7 章 transformer attention 一致。node 的更新特征为:
\[\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)\]

GCN assigns fixed equal weights to all neighbours; GAT learns data-dependent attention weights

  • 与 GCN 的关键区别在于:\(\alpha_{ij}\)从数据中学习得到的,而非由 graph 结构固定决定。node 可学会聚焦最有信息的邻居,同时抑制噪声或无关邻居。

  • 注意 attention 只在已有 edge 上计算(node \(i\) 只 attention 到其邻居 \(\mathcal{N}(i)\)),而非所有 node 对。这使计算量与 edge 数成正比,而不是与 node 数平方成正比。

多头 Graph Attention

  • 与 transformer(第 7 章)一样,multi-head attention 会并行运行 \(K\) 个独立 attention 机制,每个 head 有自己的参数 \(W^k\)\(\mathbf{a}^k\)。中间层通常做拼接,最终层通常做平均:
\[\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)\]
  • 每个 head 可关注邻域不同方面:一个 head 关注结构特征,另一个关注语义相似性。这与 transformer 多头机制的动机相同:不同 head 捕获不同关系类型。

  • 若有 \(K\) 个 head 且每个 head 输出维度为 \(d'\),拼接后输出维度为 \(K \times d'\)。最终层通常用平均而非拼接,以得到固定维度输出。

GATv2:修复静态 Attention

  • 原始 GAT 有个细微限制:其 attention 函数是 static(也称 ranking-based)。attention score 依赖拼接项 \([W\mathbf{h}_i \| W\mathbf{h}_j]\),但由于 attention 向量 \(\mathbf{a}\) 在拼接后才作用,可拆成两项独立部分:\(\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j\)

  • 这意味着:对给定 node \(i\),邻居排序几乎完全由邻居特征 \(\mathbf{h}_j\) 决定(\(\mathbf{a}_1^T W\mathbf{h}_i\)\(i\) 的所有邻居是常数)。attention 排序并未真正依赖 query node 本身特征。结果是 node \(i\) 与 node \(k\) 对同一组邻居会得到相同排序,表达力受限。

  • GATv2(Brody et al., 2022)通过把非线性提前到 attention 向量之前来修复:

\[e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)\]
  • 将 LeakyReLU 移入内部后,attention score 成为 joint feature 的非线性函数,不再能分解为独立项。attention 因此变成 dynamic:邻居排序会依赖具体 query node。GATv2 在不增加计算成本的前提下,表达力严格强于 GAT。

Graph Transformer

  • 标准 message-passing GNN 受 graph topology 限制:node 只能 attention 到一跳邻居。经过 \(k\) 层后,\(k\)-hop 信息会在多次聚合中被混合稀释。这个局部瓶颈(叠加文件 3 的 over-smoothing)限制了长程依赖建模能力。

  • Graph Transformer 通过对所有 node 对施加 global self-attention 打破该瓶颈,不管两者是否有 edge。单层内每个 node 都可 attention 到所有其他 node,和标准 transformer(第 7 章)一致。

  • 基本思路是:把所有 node 当成 token,直接应用 transformer self-attention:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
  • 其中 \(Q = XW_Q\)\(K = XW_K\)\(V = XW_V\) 是 node feature \(X\) 的 query/key/value 投影(与第 7 章完全一致)。这可视作在 fully connected graph(complete graph \(K_n\),见文件 2)上的 GNN。

  • 问题是:fully connected graph 会忽略真实 graph 结构,edge 信息(谁与谁真正相连)丢失。常见有两种补救思路:

  • Graphormer(Ying et al., 2021)通过在 attention score 中加入 bias term 注入 graph 结构:

\[A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)\]
  • 其中空间 bias \(b_{\text{spatial}}\) 编码 node \(i,j\) 的 shortest-path 距离,edge bias \(b_{\text{edge}}\) 编码 shortest path 上的 edge feature。Graphormer 还加入 centrality encoding,把 node degree 加入输入 embedding,使模型感知 node 的结构角色。

  • GPS(General, Powerful, Scalable Graph Transformer,Rampášek et al., 2022)在每层同时结合局部 message passing 与全局 attention:

\[\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)\]
  • 每层同时做标准 GNN(提取局部结构)和 transformer(建模全局上下文),再融合结果。这样兼得两者优势:message passing 的局部结构 + attention 的长程依赖。

Positional 与 Structural Encoding

  • sequence 上的 transformer 使用 positional encoding(第 7 章)注入顺序信息。graph 没有规范顺序,因此需要 graph 专用编码。

  • Laplacian eigenvector encoding 使用 graph Laplacian(文件 2)的特征向量作为位置 feature。最小的 \(k\) 个非平凡特征向量提供了 graph 的谱嵌入:graph 上“相近”的 node 往往有相似特征向量值。通常将其与 node feature 拼接。

  • 一个细节是:Laplacian 特征向量存在符号不确定性(若 \(\mathbf{u}\) 是特征向量,则 \(-\mathbf{u}\) 也是)。模型应对符号翻转保持不敏感。常见做法包括训练时随机符号翻转作为数据增强,或学习 sign-invariant 变换。

  • random walk encoding 计算从 node \(i\) 出发、\(k\) 步后回到 \(i\) 的概率(\(k = 1,2,\ldots,K\))。这些概率编码局部结构:稠密簇中的 node 回返概率高,稀疏区域较低。落点概率为 \(p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}\),其中 \(A_{\text{rw}} = D^{-1}A\) 是 random walk 转移矩阵。

  • degree encoding 则直接把 node degree 作为 feature。它常出乎意料地有效,因为 degree 是强结构信号:叶子 node(degree 1)、桥接 node、hub node 的行为通常不同。

  • 这些编码补足了 vanilla transformer 缺失的结构信息,使 Graph Transformer 在需要长程推理的任务上常优于标准 message-passing GNN。

Scalability

  • GNN 的根本可扩展性挑战在于:graph 可能有百万 node、十亿 edge。在全图训练时,需要在内存中存储所有 node feature 与整张 adjacency,这通常不可行。

  • GNN 的 mini-batch 训练 比图像或序列复杂,因为 node 相互连接。若直接采样一批目标 node,还需其一跳邻居(layer 1)、二跳邻居(layer 2)等,形成 neighbourhood explosion。例如 1000 个目标 node 的 batch 可能膨胀到百万级计算 node。

  • neighbourhood sampling(GraphSAGE 风格,见文件 3)通过每层每 node 固定采样邻居数来限制膨胀。若 2 层且每层采样 15 个邻居,则每个目标 node 的子图最多 \(15^2 = 225\) 个 node,与全图规模无关。

  • Cluster-GCN(Chiang et al., 2019)先用图聚类算法(如 METIS)把 graph 切成 cluster,再按 cluster 训练。cluster 内 edge 更密(多数邻居在同一 cluster),子图可保留主要结构;跨 cluster edge 通过周期性引入 cluster 间边来处理。

  • Graph Transformer 的可扩展性更难,因为 global attention 是 \(O(n^2)\)。对百万级 node graph,全注意力不可行。常见方案包括:

    • Sparse attention patterns (attend only to \(k\)-nearest nodes in the graph)
    • Linear attention approximations
    • Combining local message passing (cheap, \(O(|E|)\)) with global attention on a coarsened graph (fewer nodes)

Temporal 与 Dynamic Graph

  • 目前讨论的 graph 多是 static:node、edge、feature 都固定。但许多真实 graph 会随时间演化:社交网络新增用户、金融交易不断产生边、交通流在一天内波动、分子相互作用也会变化。

  • temporal graph 给每条 edge 加上时间戳:\((i, j, t)\) 表示 node \(i\) 在时刻 \(t\) 与 node \(j\) 发生交互。挑战是学习同时捕获 graph 结构与时间动态的表示。

  • 常见有两种范式:

  • 离散时间 dynamic graph(DTDG):将 graph 表示为快照序列 \(G_1, G_2, \ldots, G_T\)(每个时间步一张)。用 GNN 处理每个快照,再用 RNN 或 temporal attention 捕获快照间演化。优点是简单,缺点是丢失细粒度时序信息(快照间事件丢失)且需手动选择快照频率。

  • 连续时间 dynamic graph(CTDG):把事件建模为带时间戳交互流。每个事件 \((i, j, t)\) 在其发生时刻精确更新 node \(i,j\) 表示,从而保留全部时序信息。

  • Temporal Graph Network(TGN)(Rossi et al., 2020)是代表性的 CTDG 架构。每个 node 维护一个 memory state \(\mathbf{s}_i(t)\),并在其参与交互时更新:

\[\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)\]
  • 其中 \(\mathbf{m}_i(t)\) 是由交互计算得到的 message(融合两端 node feature、edge feature 和 time encoding)。GRU(第 6 章)通过选择性记忆与遗忘,让 memory 既能保留长期模式,也能适配近期事件。

  • time encoding 将“距上次交互过去多久”编码为 feature vector,类似 transformer(第 7 章)的 positional encoding。常见方式是可学习 Fourier feature:

\[\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]\]
  • 这使模型能细致表达时间间隔差异:“5 分钟前活跃”与“3 个月前活跃”会被映射到不同表示。

  • Temporal Graph Attention(TGAT) 对 node 的 temporal neighbourhood(近期交互集合)做 self-attention,权重同时考虑 feature 相关性(类似 GAT)与时间新近性。久远历史交互会自然被降权。

  • 应用包括:欺诈检测(金融 graph 中异常交易模式)、交通预测(基于历史流量预测拥堵)、社交动态分析(预测内容传播)、以及随时间变化的药物相互作用预测。

Coding Tasks(使用 CoLab 或 notebook)

  1. 从零实现单个 GAT attention head。计算一个 node 与其邻居之间的 attention 权重,并验证权重和为 1。

    import jax
    import jax.numpy as jnp
    
    rng = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(rng, 3)
    
    n_nodes, d_in, d_out = 5, 4, 3
    
    # Random node features
    H = jax.random.normal(k1, (n_nodes, d_in))
    
    # Learnable parameters
    W = jax.random.normal(k2, (d_in, d_out)) * 0.5
    a = jax.random.normal(k3, (2 * d_out,)) * 0.5
    
    # Adjacency (node 0 connects to 1, 2, 3)
    neighbours_of_0 = [1, 2, 3]
    
    # Transform features
    Wh = H @ W  # (n_nodes, d_out)
    
    # Compute attention scores for node 0
    h_i = Wh[0]
    scores = []
    for j in neighbours_of_0:
        h_j = Wh[j]
        e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
        e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
        scores.append(float(e_ij))
    
    scores = jnp.array(scores)
    alpha = jax.nn.softmax(scores)
    
    print(f"Raw scores: {scores}")
    print(f"Attention weights: {alpha}")
    print(f"Sum of weights: {alpha.sum():.4f}")
    
    # Weighted aggregation
    h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
    print(f"Updated node 0 features: {h_new}")
    

  2. 对比 GCN(固定权重)与 GAT(可学习权重)的聚合。展示 GAT 可对邻居分配不同权重,而 GCN 更接近均匀处理。

    import jax
    import jax.numpy as jnp
    
    # 4 nodes: node 0 connects to 1, 2, 3
    A = jnp.array([[0,1,1,1],
                   [1,0,0,0],
                   [1,0,0,0],
                   [1,0,0,0]], dtype=float)
    
    # Features: node 1 is very relevant, node 2 is noise, node 3 is moderate
    H = jnp.array([[0.0, 0.0],   # node 0
                   [1.0, 0.0],   # node 1 (signal)
                   [0.0, 0.0],   # node 2 (noise)
                   [0.5, 0.0]])  # node 3 (moderate)
    
    # GCN: normalised adjacency weights
    A_hat = A + jnp.eye(4)
    D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
    gcn_weights = (D_inv @ A_hat)[0]  # weights for node 0
    print(f"GCN weights for node 0: {gcn_weights}")
    print("  → All neighbours get roughly equal weight")
    
    # GAT: learned attention (simulated)
    # Suppose the attention mechanism learns to focus on node 1
    gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # learned
    print(f"\nGAT weights for node 0: {gat_weights}")
    print("  → Node 1 (informative) gets most attention")
    
    gcn_output = gcn_weights @ H
    gat_output = gat_weights @ H
    print(f"\nGCN output: {gcn_output}  (diluted by noise)")
    print(f"GAT output: {gat_output}  (focused on signal)")
    

  3. 演示 positional encoding 的作用。计算 graph 的 Laplacian eigenvector encoding,并展示结构相似 node 会得到相似编码。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # Barbell graph: two cliques connected by a bridge
    n = 10
    A = jnp.zeros((n, n))
    # Clique 1: nodes 0-4
    for i in range(5):
        for j in range(i+1, 5):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # Clique 2: nodes 5-9
    for i in range(5, 10):
        for j in range(i+1, 10):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # Bridge
    A = A.at[4,5].set(1).at[5,4].set(1)
    
    D = jnp.diag(A.sum(axis=1))
    L = D - A
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    
    # Use first 3 non-trivial eigenvectors as positional encoding
    pe = eigenvectors[:, 1:4]
    
    print("Laplacian Positional Encodings:")
    for i in range(n):
        group = "Clique 1" if i < 5 else "Clique 2"
        bridge = " (bridge)" if i in [4, 5] else ""
        print(f"  Node {i} ({group}{bridge}): {pe[i]}")
    
    plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="Clique 1")
    plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="Clique 2")
    plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
                label="Bridge nodes", zorder=5)
    plt.legend(); plt.grid(True)
    plt.title("Laplacian Eigenvector Positional Encodings")
    plt.xlabel("Eigenvector 1"); plt.ylabel("Eigenvector 2")
    plt.show()