Graph Attention Networks¶
Graph attention network 用可学习、数据依赖的权重替代了统一的邻居聚合。本文件覆盖 GAT、多头 graph attention、GATv2、Graph Transformer、位置与结构编码以及可扩展性。
-
在 GCN(文件 3)中,每个 node 都用由 graph 结构(normalised adjacency)决定的固定权重来聚合邻居 feature。若一个 node 有三个邻居,每个邻居大约分到 \(\approx 1/3\) 权重。但现实中邻居的重要性并不相同:来自紧密协作者的消息通常应比来自远关系节点的消息更重要。
-
Graph Attention Network 通过学习“该 attention 哪些邻居”来解决这个问题,机制与第 7 章 transformer 的 attention 相同。与固定的结构权重不同,每个 node 会对邻居动态计算基于内容的 attention score。
GAT:Graph Attention Network¶
- GAT(Veličković et al., 2018)在 node 与其邻居之间计算 attention 系数。对 node \(i\) 及其邻居 \(j\):
-
其中 \(W \in \mathbb{R}^{d' \times d}\) 是共享线性变换,\(\|\) 表示 concatenation,\(\mathbf{a} \in \mathbb{R}^{2d'}\) 是可学习的 attention 向量。分数 \(e_{ij}\) 衡量 node \(j\) 的 feature 对 node \(i\) 的重要性。
-
原始分数会在所有邻居上通过 softmax 做归一化:
- 这保证每个 node 的邻域 attention 权重和为 1,与第 7 章 transformer attention 一致。node 的更新特征为:
-
与 GCN 的关键区别在于:\(\alpha_{ij}\) 是从数据中学习得到的,而非由 graph 结构固定决定。node 可学会聚焦最有信息的邻居,同时抑制噪声或无关邻居。
-
注意 attention 只在已有 edge 上计算(node \(i\) 只 attention 到其邻居 \(\mathcal{N}(i)\)),而非所有 node 对。这使计算量与 edge 数成正比,而不是与 node 数平方成正比。
多头 Graph Attention¶
- 与 transformer(第 7 章)一样,multi-head attention 会并行运行 \(K\) 个独立 attention 机制,每个 head 有自己的参数 \(W^k\) 和 \(\mathbf{a}^k\)。中间层通常做拼接,最终层通常做平均:
-
每个 head 可关注邻域不同方面:一个 head 关注结构特征,另一个关注语义相似性。这与 transformer 多头机制的动机相同:不同 head 捕获不同关系类型。
-
若有 \(K\) 个 head 且每个 head 输出维度为 \(d'\),拼接后输出维度为 \(K \times d'\)。最终层通常用平均而非拼接,以得到固定维度输出。
GATv2:修复静态 Attention¶
-
原始 GAT 有个细微限制:其 attention 函数是 static(也称 ranking-based)。attention score 依赖拼接项 \([W\mathbf{h}_i \| W\mathbf{h}_j]\),但由于 attention 向量 \(\mathbf{a}\) 在拼接后才作用,可拆成两项独立部分:\(\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j\)。
-
这意味着:对给定 node \(i\),邻居排序几乎完全由邻居特征 \(\mathbf{h}_j\) 决定(\(\mathbf{a}_1^T W\mathbf{h}_i\) 对 \(i\) 的所有邻居是常数)。attention 排序并未真正依赖 query node 本身特征。结果是 node \(i\) 与 node \(k\) 对同一组邻居会得到相同排序,表达力受限。
-
GATv2(Brody et al., 2022)通过把非线性提前到 attention 向量之前来修复:
- 将 LeakyReLU 移入内部后,attention score 成为 joint feature 的非线性函数,不再能分解为独立项。attention 因此变成 dynamic:邻居排序会依赖具体 query node。GATv2 在不增加计算成本的前提下,表达力严格强于 GAT。
Graph Transformer¶
-
标准 message-passing GNN 受 graph topology 限制:node 只能 attention 到一跳邻居。经过 \(k\) 层后,\(k\)-hop 信息会在多次聚合中被混合稀释。这个局部瓶颈(叠加文件 3 的 over-smoothing)限制了长程依赖建模能力。
-
Graph Transformer 通过对所有 node 对施加 global self-attention 打破该瓶颈,不管两者是否有 edge。单层内每个 node 都可 attention 到所有其他 node,和标准 transformer(第 7 章)一致。
-
基本思路是:把所有 node 当成 token,直接应用 transformer self-attention:
-
其中 \(Q = XW_Q\)、\(K = XW_K\)、\(V = XW_V\) 是 node feature \(X\) 的 query/key/value 投影(与第 7 章完全一致)。这可视作在 fully connected graph(complete graph \(K_n\),见文件 2)上的 GNN。
-
问题是:fully connected graph 会忽略真实 graph 结构,edge 信息(谁与谁真正相连)丢失。常见有两种补救思路:
-
Graphormer(Ying et al., 2021)通过在 attention score 中加入 bias term 注入 graph 结构:
-
其中空间 bias \(b_{\text{spatial}}\) 编码 node \(i,j\) 的 shortest-path 距离,edge bias \(b_{\text{edge}}\) 编码 shortest path 上的 edge feature。Graphormer 还加入 centrality encoding,把 node degree 加入输入 embedding,使模型感知 node 的结构角色。
-
GPS(General, Powerful, Scalable Graph Transformer,Rampášek et al., 2022)在每层同时结合局部 message passing 与全局 attention:
- 每层同时做标准 GNN(提取局部结构)和 transformer(建模全局上下文),再融合结果。这样兼得两者优势:message passing 的局部结构 + attention 的长程依赖。
Positional 与 Structural Encoding¶
-
sequence 上的 transformer 使用 positional encoding(第 7 章)注入顺序信息。graph 没有规范顺序,因此需要 graph 专用编码。
-
Laplacian eigenvector encoding 使用 graph Laplacian(文件 2)的特征向量作为位置 feature。最小的 \(k\) 个非平凡特征向量提供了 graph 的谱嵌入:graph 上“相近”的 node 往往有相似特征向量值。通常将其与 node feature 拼接。
-
一个细节是:Laplacian 特征向量存在符号不确定性(若 \(\mathbf{u}\) 是特征向量,则 \(-\mathbf{u}\) 也是)。模型应对符号翻转保持不敏感。常见做法包括训练时随机符号翻转作为数据增强,或学习 sign-invariant 变换。
-
random walk encoding 计算从 node \(i\) 出发、\(k\) 步后回到 \(i\) 的概率(\(k = 1,2,\ldots,K\))。这些概率编码局部结构:稠密簇中的 node 回返概率高,稀疏区域较低。落点概率为 \(p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}\),其中 \(A_{\text{rw}} = D^{-1}A\) 是 random walk 转移矩阵。
-
degree encoding 则直接把 node degree 作为 feature。它常出乎意料地有效,因为 degree 是强结构信号:叶子 node(degree 1)、桥接 node、hub node 的行为通常不同。
-
这些编码补足了 vanilla transformer 缺失的结构信息,使 Graph Transformer 在需要长程推理的任务上常优于标准 message-passing GNN。
Scalability¶
-
GNN 的根本可扩展性挑战在于:graph 可能有百万 node、十亿 edge。在全图训练时,需要在内存中存储所有 node feature 与整张 adjacency,这通常不可行。
-
GNN 的 mini-batch 训练 比图像或序列复杂,因为 node 相互连接。若直接采样一批目标 node,还需其一跳邻居(layer 1)、二跳邻居(layer 2)等,形成 neighbourhood explosion。例如 1000 个目标 node 的 batch 可能膨胀到百万级计算 node。
-
neighbourhood sampling(GraphSAGE 风格,见文件 3)通过每层每 node 固定采样邻居数来限制膨胀。若 2 层且每层采样 15 个邻居,则每个目标 node 的子图最多 \(15^2 = 225\) 个 node,与全图规模无关。
-
Cluster-GCN(Chiang et al., 2019)先用图聚类算法(如 METIS)把 graph 切成 cluster,再按 cluster 训练。cluster 内 edge 更密(多数邻居在同一 cluster),子图可保留主要结构;跨 cluster edge 通过周期性引入 cluster 间边来处理。
-
Graph Transformer 的可扩展性更难,因为 global attention 是 \(O(n^2)\)。对百万级 node graph,全注意力不可行。常见方案包括:
- Sparse attention patterns (attend only to \(k\)-nearest nodes in the graph)
- Linear attention approximations
- Combining local message passing (cheap, \(O(|E|)\)) with global attention on a coarsened graph (fewer nodes)
Temporal 与 Dynamic Graph¶
-
目前讨论的 graph 多是 static:node、edge、feature 都固定。但许多真实 graph 会随时间演化:社交网络新增用户、金融交易不断产生边、交通流在一天内波动、分子相互作用也会变化。
-
temporal graph 给每条 edge 加上时间戳:\((i, j, t)\) 表示 node \(i\) 在时刻 \(t\) 与 node \(j\) 发生交互。挑战是学习同时捕获 graph 结构与时间动态的表示。
-
常见有两种范式:
-
离散时间 dynamic graph(DTDG):将 graph 表示为快照序列 \(G_1, G_2, \ldots, G_T\)(每个时间步一张)。用 GNN 处理每个快照,再用 RNN 或 temporal attention 捕获快照间演化。优点是简单,缺点是丢失细粒度时序信息(快照间事件丢失)且需手动选择快照频率。
-
连续时间 dynamic graph(CTDG):把事件建模为带时间戳交互流。每个事件 \((i, j, t)\) 在其发生时刻精确更新 node \(i,j\) 表示,从而保留全部时序信息。
-
Temporal Graph Network(TGN)(Rossi et al., 2020)是代表性的 CTDG 架构。每个 node 维护一个 memory state \(\mathbf{s}_i(t)\),并在其参与交互时更新:
-
其中 \(\mathbf{m}_i(t)\) 是由交互计算得到的 message(融合两端 node feature、edge feature 和 time encoding)。GRU(第 6 章)通过选择性记忆与遗忘,让 memory 既能保留长期模式,也能适配近期事件。
-
time encoding 将“距上次交互过去多久”编码为 feature vector,类似 transformer(第 7 章)的 positional encoding。常见方式是可学习 Fourier feature:
-
这使模型能细致表达时间间隔差异:“5 分钟前活跃”与“3 个月前活跃”会被映射到不同表示。
-
Temporal Graph Attention(TGAT) 对 node 的 temporal neighbourhood(近期交互集合)做 self-attention,权重同时考虑 feature 相关性(类似 GAT)与时间新近性。久远历史交互会自然被降权。
-
应用包括:欺诈检测(金融 graph 中异常交易模式)、交通预测(基于历史流量预测拥堵)、社交动态分析(预测内容传播)、以及随时间变化的药物相互作用预测。
Coding Tasks(使用 CoLab 或 notebook)¶
-
从零实现单个 GAT attention head。计算一个 node 与其邻居之间的 attention 权重,并验证权重和为 1。
import jax import jax.numpy as jnp rng = jax.random.PRNGKey(0) k1, k2, k3 = jax.random.split(rng, 3) n_nodes, d_in, d_out = 5, 4, 3 # Random node features H = jax.random.normal(k1, (n_nodes, d_in)) # Learnable parameters W = jax.random.normal(k2, (d_in, d_out)) * 0.5 a = jax.random.normal(k3, (2 * d_out,)) * 0.5 # Adjacency (node 0 connects to 1, 2, 3) neighbours_of_0 = [1, 2, 3] # Transform features Wh = H @ W # (n_nodes, d_out) # Compute attention scores for node 0 h_i = Wh[0] scores = [] for j in neighbours_of_0: h_j = Wh[j] e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j])) e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2) scores.append(float(e_ij)) scores = jnp.array(scores) alpha = jax.nn.softmax(scores) print(f"Raw scores: {scores}") print(f"Attention weights: {alpha}") print(f"Sum of weights: {alpha.sum():.4f}") # Weighted aggregation h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0))) print(f"Updated node 0 features: {h_new}") -
对比 GCN(固定权重)与 GAT(可学习权重)的聚合。展示 GAT 可对邻居分配不同权重,而 GCN 更接近均匀处理。
import jax import jax.numpy as jnp # 4 nodes: node 0 connects to 1, 2, 3 A = jnp.array([[0,1,1,1], [1,0,0,0], [1,0,0,0], [1,0,0,0]], dtype=float) # Features: node 1 is very relevant, node 2 is noise, node 3 is moderate H = jnp.array([[0.0, 0.0], # node 0 [1.0, 0.0], # node 1 (signal) [0.0, 0.0], # node 2 (noise) [0.5, 0.0]]) # node 3 (moderate) # GCN: normalised adjacency weights A_hat = A + jnp.eye(4) D_inv = jnp.diag(1.0 / A_hat.sum(axis=1)) gcn_weights = (D_inv @ A_hat)[0] # weights for node 0 print(f"GCN weights for node 0: {gcn_weights}") print(" → All neighbours get roughly equal weight") # GAT: learned attention (simulated) # Suppose the attention mechanism learns to focus on node 1 gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15]) # learned print(f"\nGAT weights for node 0: {gat_weights}") print(" → Node 1 (informative) gets most attention") gcn_output = gcn_weights @ H gat_output = gat_weights @ H print(f"\nGCN output: {gcn_output} (diluted by noise)") print(f"GAT output: {gat_output} (focused on signal)") -
演示 positional encoding 的作用。计算 graph 的 Laplacian eigenvector encoding,并展示结构相似 node 会得到相似编码。
import jax.numpy as jnp import matplotlib.pyplot as plt # Barbell graph: two cliques connected by a bridge n = 10 A = jnp.zeros((n, n)) # Clique 1: nodes 0-4 for i in range(5): for j in range(i+1, 5): A = A.at[i,j].set(1).at[j,i].set(1) # Clique 2: nodes 5-9 for i in range(5, 10): for j in range(i+1, 10): A = A.at[i,j].set(1).at[j,i].set(1) # Bridge A = A.at[4,5].set(1).at[5,4].set(1) D = jnp.diag(A.sum(axis=1)) L = D - A eigenvalues, eigenvectors = jnp.linalg.eigh(L) # Use first 3 non-trivial eigenvectors as positional encoding pe = eigenvectors[:, 1:4] print("Laplacian Positional Encodings:") for i in range(n): group = "Clique 1" if i < 5 else "Clique 2" bridge = " (bridge)" if i in [4, 5] else "" print(f" Node {i} ({group}{bridge}): {pe[i]}") plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="Clique 1") plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="Clique 2") plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*", label="Bridge nodes", zorder=5) plt.legend(); plt.grid(True) plt.title("Laplacian Eigenvector Positional Encodings") plt.xlabel("Eigenvector 1"); plt.ylabel("Eigenvector 2") plt.show()