Skip to content

3D Graph Networks

3D graph network 将 GNN 扩展到具有空间几何的数据,在这类数据中必须正确处理旋转与平移。本文件覆盖 geometric graph、SE(3)/E(n)-equivariance、SchNet、DimeNet、EGNN、tensor field network,以及其在分子性质预测、蛋白质结构、材料科学与药物发现中的应用——这些架构让模型能够从 3D 物理世界中学习。

  • 文件 3 与文件 4 的 GNN 主要处理抽象 graph:node 有 feature,edge 表示连通关系,但没有显式 3D 空间概念。社交网络 graph 就是典型无几何结构的例子。然而,GNN 最有影响力的应用之一恰恰来自物理 3D 空间数据:分子、蛋白质、晶体、point cloud。对这类任务,node 的空间位置信息至关重要,而抽象 GNN 往往忽略这一点。

  • 难点在于 3D 数据具有几何 symmetry(见文件 1):旋转分子不应改变其性质,整体平移也不应改变。3D GNN 必须尊重这些 symmetry。若模型预测的能量会因旋转而变化,那在物理上就是错误的。

Geometric Graph

  • geometric graph 是嵌入在 3D 空间中的 graph。每个 node \(i\) 除了 feature 向量 \(\mathbf{h}_i\) 外,还具有位置 \(\mathbf{r}_i \in \mathbb{R}^3\)。edge 可由空间邻近关系定义(例如连接距离小于 \(r_{\text{cut}}\) 的 node),而不一定来自显式化学键。

  • 对分子来说,geometric graph 的 node 是原子(feature 可包含元素类型、电荷等),edge 是化学键。3D 位置 \(\mathbf{r}_i\) 即原子坐标,可来自量子力学计算或实验测量(X-ray crystallography、cryo-EM)。

  • 对 point cloud(来自 LiDAR 或 3D 扫描,见第 8 章与第 11 章),每个点都是 node,带位置及可选 feature(颜色、强度等)。edge 通常连接空间近邻点,形成 k-nearest-neighbour(kNN)graph 或 radius graph。

  • 用于 message passing 的关键几何量包括:

    • 原子间距离\(d_{ij} = \|\mathbf{r}_i - \mathbf{r}_j\|\)。distance 对旋转和平移都 invariant。若两分子的所有原子间距离一致,它们形状相同(无论朝向如何)。

    • 键角:在 node \(i\) 处,向量 \(\mathbf{r}_j - \mathbf{r}_i\)\(\mathbf{r}_k - \mathbf{r}_i\) 的夹角 \(\theta_{ijk}\)。angle 能提供超越 pairwise distance 的局部几何信息。

    • 二面角(扭转角):平面 \((i, j, k)\) 与平面 \((j, k, l)\) 的夹角 \(\phi_{ijkl}\)。dihedral 描述 3D 扭转结构,对蛋白质骨架几何尤为关键。

    • 相对位置向量\(\mathbf{r}_{ij} = \mathbf{r}_j - \mathbf{r}_i\)。它对平移 invariant,但对旋转 invariant。若使用它,就需要 equivariant(而非仅 invariant)架构。

SE(3) 与 E(n) Equivariance

  • 3D 物理数据对应的 symmetry group 是 Euclidean group \(E(3)\),包含全部旋转、反射和平移。其子群 \(SE(3)\)(Special Euclidean)包含旋转和平移,但不包含反射。

  • 一个 3D GNN 应满足:

    • 对标量输出(能量、结合亲和力)保持平移 invariant:所有原子同时平移同一向量,不应改变预测。
    • 对标量输出保持旋转 invariant:旋转分子不应改变其能量。
    • 对向量/张量输出(力、偶极矩)保持旋转 equivariant:旋转分子后,预测向量应按同样旋转变换。

SE(3)-equivariance: rotating a molecule leaves scalar predictions (energy) unchanged but rotates vector predictions (forces) correspondingly

  • 形式化地,若标量预测为 \(f\)、旋转为 \(R \in SO(3)\)
\[f(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = f(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(invariance)}\]
  • 若向量预测为 \(\mathbf{F}\)
\[\mathbf{F}(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = R \cdot \mathbf{F}(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(equivariance)}\]
  • 这些约束与文件 1 的 invariance/equivariance 框架完全一致,只是这里具体应用到 3D 旋转与平移群。

  • 架构设计通常有两条路线:

    1. invariant 架构:仅使用 invariant 几何特征(distance、angle)作为 message passing 输入。内部表示是标量(invariant)。优点是简单高效;缺点是难以自然输出向量。
    2. equivariant 架构:在网络中持续维护向量(及更高阶张量)表示,保证每层都 equivariant。表达力更强,可自然预测向量/张量,但实现更复杂。

SchNet:基于 Distance 的 Message Passing

  • SchNet(Schütt et al., 2017)是经典的 invariant 3D GNN。其关键创新是 continuous filter convolution:不像传统分子 GNN 用离散 edge 类型(如键类型),SchNet 直接根据原子间 distance 生成 message filter。

  • distance \(d_{ij}\) 首先通过 radial basis function(RBF) 展开为特征向量:

\[\text{RBF}(d_{ij}) = \left[\exp\left(-\gamma_1 (d_{ij} - \mu_1)^2\right), \ldots, \exp\left(-\gamma_K (d_{ij} - \mu_K)^2\right)\right]\]
  • 每个基函数都是以 \(\mu_k\) 为中心、宽度为 \(\gamma_k\) 的 Gaussian。可把它理解为 distance 的可学习 positional encoding:把连续 distance 映射到高维空间,便于学习距离相关相互作用。\(\mu_k\) 常均匀覆盖从 0 到 cutoff radius 的区间。

  • SchNet 从 node \(j\) 到 node \(i\) 的 message 为:

\[\mathbf{m}_{j \to i} = \mathbf{h}_j \odot W_{\text{filter}}(\text{RBF}(d_{ij}))\]
  • 其中 \(W_{\text{filter}}\) 是把 RBF 特征映射成 filter 向量的 MLP,\(\odot\) 是逐元素乘(Hadamard product,见第 2 章)。因为 filter 依赖 distance,近邻原子与远邻原子的交互可不同。逐元素乘也可视为 gating(第 6 章):由 distance 决定每个特征维度通过多少信息。

  • SchNet 只使用 distance(invariant),因此模型天然对旋转和平移 invariant,不需要额外 symmetry 处理。

DimeNet 与 SphereNet:Angle 与 Dihedral

  • 仅靠 distance 不能完整确定 3D 结构。不同分子构象可能有相同 pairwise distance 但 angle 不同(即“distance geometry ambiguity”问题)。DimeNet(Gasteiger et al., 2020)在 message passing 中显式引入 bond angle

  • DimeNet 采用 directional message passing:message 沿有向 edge 传播,且边 \((j \to i)\) 上的 message 会受边 \((k \to j)\)\((j \to i)\) 之间 angle 的影响:

\[\mathbf{m}_{kj \to ji} = f\left(\mathbf{m}_{kj}, d_{ji}, \theta_{kji}\right)\]
  • 其中 angle \(\theta_{kji}\) 会通过球 Bessel 函数与 spherical harmonics 展开(球面角信息的自然基,可类比 RBF 对 distance 的作用)。这样模型在保持 invariance 的同时获得方向信息。

  • SphereNet(Liu et al., 2022)更进一步加入 dihedral angle \(\phi_{lkji}\),捕获完整的 3D 扭转结构。可理解为逐层增强几何分辨率:

    • distance:捕获成对邻近关系
    • angle:捕获局部几何(弯曲/线性)
    • dihedral:捕获 3D 扭转(对蛋白质骨架、药物结合尤关键)
  • 几何分辨率越高,计算开销越大(distance 为 \(O(|E|)\),angle 为 \(O(|E| \cdot k)\),dihedral 为 \(O(|E| \cdot k^2)\),其中 \(k\) 为平均 degree)。

E(n) Equivariant GNN(EGNN)

  • EGNN(Satorras et al., 2021)走 equivariant 路线:不仅更新 node feature,也在每层更新 node position,并在整个过程中保持 equivariance。

  • EGNN 对 node \(i\) 的更新为:

\[\mathbf{m}_{ij} = \phi_e\left(\mathbf{h}_i, \mathbf{h}_j, d_{ij}^2, a_{ij}\right)\]
\[\mathbf{r}_i' = \mathbf{r}_i + C \sum_{j \neq i} (\mathbf{r}_i - \mathbf{r}_j) \cdot \phi_r(\mathbf{m}_{ij})\]
\[\mathbf{h}_i' = \phi_h\left(\mathbf{h}_i, \sum_j \mathbf{m}_{ij}\right)\]
  • 核心在位置更新:node position 通过相对位置向量 \((\mathbf{r}_i - \mathbf{r}_j)\) 的加权和来修正,权重来自 \(\phi_r\),而 \(\phi_r\) 仅依赖 invariant 量(feature 与 distance)。这种构造可证明保持 equivariance:若输入位置整体乘以旋转 \(R\),输出位置也会整体乘以同一 \(R\)

  • EGNN 的优雅之处在于:无需显式使用 spherical harmonics 或不可约表示即可实现 equivariance。相对位置向量提供方向信息,invariant 的 message 函数决定如何使用这些方向信息。

  • 但其简洁性也有取舍:EGNN 主要处理一阶向量表示,若要原生表示更高阶张量(如四极矩、应力张量)需额外扩展。

Tensor Field Network 与高阶表示

  • Tensor Field Network(Thomas et al., 2018)及其后续(SE(3)-TransformerMACEEquiformer)使用旋转群不可约表示理论的完整工具链来构建 equivariant layer。

  • 在表示理论中(可联系第 2 章线性代数),3D 旋转可分解为由整数阶 \(\ell\) 表征的不可约分量:

    • \(\ell = 0\):标量(1 个分量,invariant),如能量、电荷。
    • \(\ell = 1\):向量(3 个分量,按位置向量方式旋转),如力、偶极矩。
    • \(\ell = 2\):二阶对称无迹张量(5 个分量),如四极矩、应力张量。
    • 更高 \(\ell\):表达更复杂角向结构。
  • 这些称为 spherical tensor,在旋转 \(R\) 下通过 Wigner-D 矩阵 \(D^\ell(R)\) 变换:标量不变,向量按 \(R\) 旋转,二阶张量按更复杂矩阵变换。

  • 基于 spherical tensor 的 equivariant message passing 使用 Clebsch-Gordan tensor product 组合不同阶特征:

\[(\mathbf{f}^{\ell_1} \otimes \mathbf{f}^{\ell_2})^{\ell_{\text{out}}} = \sum_{m_1, m_2} C^{\ell_{\text{out}}, m_{\text{out}}}_{\ell_1, m_1, \ell_2, m_2} \cdot f^{\ell_1}_{m_1} \cdot f^{\ell_2}_{m_2}\]
  • Clebsch-Gordan 系数 \(C\) 是固定数学常数,保证该 tensor product 保持 equivariant。它可看作 SO(3)-equivariant 版本的“矩阵乘法”。

  • MACE(Batatia et al., 2022)通过更高阶 message(多邻居特征乘积)在较少 message-passing 层中获得高精度。它按体阶构造交互:2-body 来自 distance、3-body 来自 angle、many-body 来自 tensor product,从而高效捕获复杂原子相互作用。

  • Equiformer(Liao & Smidt, 2023)将 equivariant spherical tensor feature 与 transformer attention(文件 4)结合,形成 SE(3)-equivariant Graph Transformer。attention score 由 invariant feature 计算,而 value 聚合在 equivariant tensor feature 上执行。

应用场景

  • 分子性质预测:给定分子 3D 结构,预测能量、力、偶极矩、HOMO-LUMO gap、毒性、溶解度等。这是 3D GNN 最成熟的应用方向。基于量子化学数据集(QM9、OC20)训练的模型在多项性质上达到接近化学精度,可用于大规模虚拟筛选。

  • 分子动力学加速:用量子力学(如 DFT)计算原子间力代价极高(对 \(n\) 电子常见 \(O(n^3)\))。训练好的 3D GNN 可在分子动力学中替代 DFT 力计算,在接近 DFT 精度下带来 \(10^3\)\(10^6\) 级加速,支持更大体系与更长时间尺度模拟。

  • 蛋白质结构建模:蛋白质是氨基酸链折叠形成的复杂 3D 结构。其骨架可表示为 geometric graph:node 是残基,edge 连接空间近邻残基。3D GNN 可用于蛋白功能预测、结合位点识别、蛋白设计(逆折叠:给定目标结构预测氨基酸序列)。AlphaFold 也融合了几何与 graph 推理思想。

  • 材料科学与催化:晶体材料具有周期性 3D 结构。GNN 可建模重复单胞并预测带隙、形成能、力学强度等性质。Open Catalyst Project(OC20/OC22)基准即评测 GNN 对催化表面吸附能预测能力,帮助加速新能源催化剂发现。

  • 药物发现:3D GNN 用于预测药物分子与靶蛋白的结合方式。结合亲和力依赖药物与蛋白结合口袋的 3D 形状互补与化学相互作用。像 DiffDock 这样的模型将 equivariant GNN 与 diffusion(第 8 章)结合,用于预测 binding pose(药物在蛋白口袋中的 3D 位姿)。

Graph Generation

  • 前述架构都在分析已有 graph。graph generation 则是生成新 graph:例如设计目标性质分子、合成测试用社交网络、提出新蛋白结构。这是 graph-level 预测的生成式对应任务。

  • 难点在于 graph 是离散的、可变长的、组合空间巨大的。生成一个 graph 意味着要决定 node 数量、每个 node 的 feature、以及任意 node 对是否连接。可选 graph 数随 node 数超指数增长。

  • autoregressive generation 按 node(或 edge)逐步构造 graph。GraphRNN(You et al., 2018)即按序生成:RNN 维护状态,每步生成一个新 node,并决定它与已有 node 的连接。虽然这给无序 graph 强加了序列顺序,但使用 BFS 顺序通常更稳定,因为近期生成的 node 更相关。

  • VAE-based generation 先用 GNN encoder 将 graph 编到连续 latent 空间,再从采样 latent 解码 graph。GraphVAE 一次性生成概率 adjacency matrix \(\hat{A} \in [0, 1]^{n \times n}\),但其复杂度是 \(O(n^2)\),且输出偏稠密,需要阈值化。latent 空间带来平滑插值能力:在两分子 embedding 之间移动可生成化学上可行的中间结构。

  • diffusion-based generation 把第 8 章 diffusion 框架用于 graph。前向过程逐步给 node feature 与 edge 结构加噪;反向过程学习去噪,从噪声生成有效 graph。DiGress(Vignac et al., 2023)对 node 类型和 edge 类型都做离散 diffusion,天然契合 graph 数据的类别属性。

  • 分子生成而言,核心约束是化学有效性:生成分子必须满足价键规则(如碳常成 4 键、氧常成 2 键等)。Junction Tree VAE(JT-VAE) 把分子分解为合法子结构(环、链、官能团)再组合生成,可在构造层面保证有效性。

  • goal-directed generation 会针对指定目标优化:例如生成对某蛋白结合亲和力高、毒性低、溶解性好的分子。通常把 graph generation 与性质预测(3D GNN 作为评估器)组成闭环:生成 → 评估 → 精化。可用强化学习(第 6 章)或贝叶斯优化引导搜索。

  • DiffDock(Corso et al., 2023)使用 SE(3)-equivariant diffusion 预测药物分子在蛋白结合口袋中的 dock 方式。模型从随机位姿出发逐步去噪,生成 3D binding pose(药物相对蛋白的位置与朝向),把本文件的 3D equivariant network 与第 8 章 diffusion 框架结合在一起。

Coding Tasks(使用 CoLab 或 notebook)

  1. 用原子间 distance 构建一个简单的 invariant 3D message-passing layer。将其应用到小分子(water: H-O-H),并验证输出对旋转保持 invariant。

    import jax
    import jax.numpy as jnp
    
    # Water molecule: O at origin, two H atoms
    positions = jnp.array([[0.0, 0.0, 0.0],     # O
                            [0.96, 0.0, 0.0],    # H1
                            [-0.24, 0.93, 0.0]])  # H2
    
    # Node features: [atomic number]
    features = jnp.array([[8.0], [1.0], [1.0]])
    
    # Compute pairwise distances (invariant)
    def pairwise_distances(pos):
        diff = pos[:, None, :] - pos[None, :, :]
        return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8)
    
    # Simple distance-based message passing
    def invariant_message_pass(features, positions):
        dists = pairwise_distances(positions)
        # RBF expansion with 4 centres
        centres = jnp.array([0.5, 1.0, 1.5, 2.0])
        rbf = jnp.exp(-5.0 * (dists[:, :, None] - centres[None, None, :]) ** 2)
    
        # Message: features weighted by distance-dependent filter
        messages = jnp.einsum("ij,jd->id", rbf.sum(axis=-1), features)
        return messages
    
    output1 = invariant_message_pass(features, positions)
    
    # Rotate the molecule by 90 degrees around z-axis
    R = jnp.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
    rotated_positions = (R @ positions.T).T
    
    output2 = invariant_message_pass(features, rotated_positions)
    
    print(f"Original output:\n{output1}")
    print(f"\nRotated output:\n{output2}")
    print(f"\nInvariant: {jnp.allclose(output1, output2, atol=1e-5)}")
    

  2. 计算三个原子构成的 bond angle,并验证该角度对旋转 invariant。

    import jax.numpy as jnp
    
    def bond_angle(r_i, r_j, r_k):
        """Angle at node j between edges j->i and j->k."""
        v1 = r_i - r_j
        v2 = r_k - r_j
        cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))
        return jnp.arccos(jnp.clip(cos_angle, -1, 1))
    
    # Three atoms
    r1 = jnp.array([1.0, 0.0, 0.0])
    r2 = jnp.array([0.0, 0.0, 0.0])
    r3 = jnp.array([0.0, 1.0, 0.0])
    
    angle_original = bond_angle(r1, r2, r3)
    print(f"Original angle: {jnp.degrees(angle_original):.1f}°")
    
    # Apply random rotation
    R = jnp.array([[0.36, 0.48, -0.80],
                   [-0.80, 0.60, 0.00],
                   [0.48, 0.64, 0.60]])
    r1_rot, r2_rot, r3_rot = R @ r1, R @ r2, R @ r3
    
    angle_rotated = bond_angle(r1_rot, r2_rot, r3_rot)
    print(f"Rotated angle:  {jnp.degrees(angle_rotated):.1f}°")
    print(f"Invariant: {jnp.allclose(angle_original, angle_rotated, atol=1e-4)}")
    

  3. 演示 EGNN 风格的 equivariant 位置更新。使用按 distance 加权的相对位置向量更新 node 位置,并验证 equivariance。

    import jax
    import jax.numpy as jnp
    
    def egnn_position_update(positions, features):
        """Simple EGNN-style equivariant position update."""
        n = positions.shape[0]
        new_positions = jnp.zeros_like(positions)
    
        for i in range(n):
            shift = jnp.zeros(3)
            for j in range(n):
                if i != j:
                    r_ij = positions[i] - positions[j]
                    d_ij = jnp.linalg.norm(r_ij)
                    # Weight based on distance (simple: inverse distance)
                    weight = 1.0 / (d_ij + 1.0)
                    # Scale by feature similarity
                    feat_sim = jnp.dot(features[i], features[j])
                    shift = shift + weight * feat_sim * r_ij
            new_positions = new_positions.at[i].set(positions[i] + 0.1 * shift)
    
        return new_positions
    
    # 3 atoms
    pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    feat = jnp.array([[1.0, 0.5], [0.5, 1.0], [0.8, 0.3]])
    
    # Update positions
    pos_new = egnn_position_update(pos, feat)
    
    # Now rotate input, update, and check if output is rotated consistently
    R = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    pos_rot = (R @ pos.T).T
    pos_new_from_rot = egnn_position_update(pos_rot, feat)
    
    # Should be the same as rotating the original output
    pos_new_then_rot = (R @ pos_new.T).T
    
    print(f"Update then rotate:\n{jnp.round(pos_new_then_rot, 4)}")
    print(f"\nRotate then update:\n{jnp.round(pos_new_from_rot, 4)}")
    print(f"\nEquivariant: {jnp.allclose(pos_new_then_rot, pos_new_from_rot, atol=1e-4)}")