Skip to content

为什么用 C++ 以及 ML 框架如何工作

本书中每一个 jnp.matmul、每一个 torch.nn.Linear、每一个 np.dot 调用,底层都在执行 C++ 和 CUDA 代码。本文揭开这层面纱:ML 框架为什么这样构建、面向 Python 工程师的 C++ 快速入门、何时需要编写自定义 C++ kernel,以及如何将其绑定到 Python——这是你写的代码与运行它的硬件之间的桥梁。

  • 你已经写了 15 章的 Python。你 import 了 JAX,调用了 jax.grad,运行了训练循环,构建了模型。一切感觉都像 Python。但事实是:几乎所有实际的计算都不是在 Python 中发生的。

  • 当你在 PyTorch 中写 output = model(input),或在 JAX 中写 output = jnp.matmul(W, x) 时,Python 几乎什么都没做。它构建了一个对计算的描述(一个操作图),然后将其交给 C++/CUDA 后端去完成真正的工作。Python 是方向盘,C++ 是发动机。

为什么是 Python 前端、C++ 后端

  • 这种双语言架构之所以存在,是因为 Python 和 C++ 各自擅长完全相反的事情:
Python C++
开发速度 快(动态类型、REPL、无需编译) 慢(静态类型、头文件、编译时间)
执行速度 比 C 慢约 100 倍(解释型,GIL) 接近硬件速度(编译型,无开销)
内存控制 自动(GC),无法控制布局 手动,对每个字节精确控制
硬件访问 无(无 SIMD、无 GPU、无自定义内存) 完整(intrinsic、CUDA、内联 assembly)
生态系统 ML 丰富(Notebook、可视化、数据) 系统丰富(OS、驱动、引擎)
  • 核心思路:让每种语言做它擅长的事。Python 处理人类生产力重要的部分(实验设计、超参数调优、数据探索)。C++ 处理机器性能重要的部分(矩阵乘法、卷积、attention kernel)。

  • 单次矩阵乘法 jnp.matmul(A, B),其中 \(A\)\(4096 \times 4096\),需要执行约 1370 亿次浮点运算。用纯 Python(嵌套循环)大约需要 30 分钟。用带 AVX-512 SIMD 和多线程的优化 C++ 大约需要 10 毫秒。这是 18 万倍的差距。任何 Python 技巧都无法弥补这个差距。

ML 框架的结构

  • 每个主流 ML 框架都遵循相同的架构:
用户代码(Python)
Python API 层(torch.nn, jax.numpy, numpy)
Dispatch / JIT 编译器(torch.compile, XLA, NumPy dispatch)
C++ kernel 库(ATen/PyTorch, XLA, BLAS/LAPACK)
硬件专用后端(CUDA, cuDNN, MKL, oneDNN, Metal)
硬件(CPU SIMD 单元、GPU 核心、TPU MXU)

NumPy

  • NumPy 的核心用 C 编写。当你调用 np.dot(A, B) 时,Python 调用一个 C 函数,该函数再调用 BLAS(Basic Linear Algebra Subprograms),通常是 Intel MKL 或 OpenBLAS。BLAS 是手工优化的 C 和 Fortran 代码,使用 SIMD 指令、缓存感知内存访问模式和多线程。数十年的优化使矩阵乘法变得极快。

  • NumPy 只使用 CPU,不使用 GPU。但在 CPU 上,它因为委托给了最佳可用的 BLAS 实现而极其高效。

PyTorch

  • PyTorch 的计算引擎是 ATen(A Tensor Library),用 C++ 编写。ATen 实现了约 2000 个 tensor 操作(add、matmul、conv2d、softmax 等),每个操作都有 CPU 和 CUDA 后端。

  • 当你调用 torch.matmul(A, B) 时:

    1. Python dispatch 到 ATen C++ 函数。
    2. ATen 检查设备(CPU 或 CUDA)和 dtype。
    3. 在 CPU 上:调用 MKL/OpenBLAS。在 GPU 上:调用 cuBLAS(NVIDIA 的 GPU 优化 BLAS)。
    4. 结果被包装成 Python tensor 对象并返回。
  • torch.compile(PyTorch 2.0+)更进一步:它追踪你的 Python 代码,构建计算图,并使用 Triton(GPU)或 C++/OpenMP(CPU)编译。编译后的代码融合操作、消除 Python 开销,可以比 eager 模式快 2-5 倍。

JAX

  • JAX 将 Python 函数编译为 XLA(Accelerated Linear Algebra),Google 用于 ML 工作负载的编译器。当你对函数使用 jax.jit 时:

    1. JAX 追踪函数,将操作捕获为 XLA 计算图(HLO——高级操作)。
    2. XLA 优化图:融合操作、消除冗余计算、优化内存布局。
    3. XLA 编译到目标后端:CPU(通过 LLVM)、GPU(通过 CUDA/PTX)或 TPU(通过 TPU 专用指令)。
    4. 编译后的代码直接在硬件上运行,无任何 Python 参与。
  • 这就是为什么 jax.jit 如此重要:没有它,每个操作都是一次独立的 Python→C++ 往返。有了它,整个函数就是一个编译好的 kernel。

面向 Python 工程师的 C++ 快速入门

  • 你不需要成为 C++ 专家。你需要足够读懂 kernel 代码、编写简单扩展,并理解性能讨论。以下是核心要点。

类型与变量

// C++ 需要显式类型(与 Python 不同)
int count = 0;           // 32 位整数
float loss = 0.5f;       // 32 位 float
double lr = 3e-4;        // 64 位 float
bool training = true;    // 布尔值

// 数组(固定大小,在栈上分配)
float weights[1024];     // 1024 个 float,在内存中连续存储

// 指针:持有内存地址的变量
float* ptr = weights;    // ptr 指向 weights 的第一个元素
float val = ptr[42];     // 通过指针算术访问第 42 号元素
// ptr[42] 等价于 *(ptr + 42)
  • 指针是与 Python 最大的概念差异。在 Python 中,一切都是引用,你从不需要考虑内存地址。在 C++ 中,指针让你直接访问内存——强大但危险(悬空指针、缓冲区溢出)。

函数

// 函数声明:返回类型 函数名(参数类型 参数名)
float relu(float x) {
    return x > 0.0f ? x : 0.0f;
}

// 按引用传递(避免复制大对象)
void scale_vector(std::vector<float>& vec, float factor) {
    for (size_t i = 0; i < vec.size(); i++) {
        vec[i] *= factor;
    }
}

// const 引用:只读,不复制
float sum(const std::vector<float>& vec) {
    float total = 0.0f;
    for (float x : vec) {  // 基于范围的 for 循环(类似 Python 的 for x in vec)
        total += x;
    }
    return total;
}

内存:栈 vs 堆

// 栈分配:快,自动生命周期(函数返回时释放)
float buffer[256];   // 栈上的 256 个 float

// 堆分配:手动,函数结束后仍存在
float* data = new float[n];   // 在堆上分配 n 个 float
// ... 使用 data ...
delete[] data;                 // 你必须手动释放(没有垃圾回收器)

// 现代 C++:智能指针(自动清理,类似 Python 引用)
#include <memory>
auto data = std::make_unique<float[]>(n);  // 超出作用域时自动释放
  • 核心规则:栈快但有限(通常 1-8 MB)。大型数组(tensor、feature map)必须放在堆上。在 Python 中,一切都在堆上,GC 负责清理。在 C++ 中,你自己管理(或使用智能指针)。

模板(泛型)

// 适用于任何数值类型的函数
template <typename T>
T add(T a, T b) {
    return a + b;
}

add<float>(1.5f, 2.5f);   // 返回 4.0f
add<int>(3, 4);             // 返回 7
  • 模板是 C++ 库(如 ATen)在不重复实现的情况下编写适用于 float16、float32、float64 等的代码的方式。

标准库要点

#include <vector>      // 动态数组(类似 Python list)
#include <string>      // 字符串类型
#include <unordered_map>  // 哈希表(类似 Python dict)
#include <algorithm>   // sort, find, transform 等
#include <cmath>       // 数学函数

std::vector<float> vec = {1.0f, 2.0f, 3.0f};
vec.push_back(4.0f);            // 追加
float first = vec[0];           // 索引
size_t len = vec.size();        // 长度

std::unordered_map<std::string, int> counts;
counts["hello"] = 5;            // 插入
if (counts.count("hello")) { }  // 检查是否存在

何时编写自定义 C++ Kernel

  • 大多数 ML 工程师永远不需要写 C++。框架内置的操作覆盖了 99% 的用例。只有在以下情况下才考虑自定义 C++:

  • 你需要的操作在框架中不存在:一个新颖的激活函数、自定义 attention 模式、无法用现有操作组合表达的特殊损失函数。

  • 为性能融合操作:你的模型执行 relu(layernorm(matmul(x, W) + b))。每个操作启动一个单独的 kernel,读写内存并同步。融合的 kernel 在一次遍历中完成所有操作,避免内存往返。这可以快 2-5 倍。

  • 减少内存占用:自定义 kernel 可以在不存储所有中间激活值的情况下计算梯度(kernel 层面的梯度检查点)。

  • 针对新型硬件:新加速器(如 Cerebras、Groq)可能没有框架支持。你直接编写 kernel。

  • 对于情况 1-2,Triton(第 16 章,第 05 文件)通常已足够,且比直接写 CUDA C 容易得多。只有当 Triton 无法表达你需要的模式时,才降级到 CUDA C。

如何将 C++ 绑定到 Python

  • 写 C++ 是工作的一半。你还需要从 Python 中调用它。

pybind11(通用用途)

  • pybind11 用最少的样板代码为 C++ 函数创建 Python 绑定:
// my_ops.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;

// 一个简单的自定义操作
py::array_t<float> custom_relu(py::array_t<float> input) {
    auto buf = input.request();
    float* ptr = static_cast<float*>(buf.ptr);
    size_t n = buf.size;

    auto result = py::array_t<float>(n);
    float* out = static_cast<float*>(result.request().ptr);

    for (size_t i = 0; i < n; i++) {
        out[i] = ptr[i] > 0 ? ptr[i] : 0;
    }
    return result;
}

PYBIND11_MODULE(my_ops, m) {
    m.def("custom_relu", &custom_relu, "自定义 ReLU 操作");
}
# 编译
pip install pybind11
c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix)
# 从 Python 中使用
import my_ops
import numpy as np

x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
y = my_ops.custom_relu(x)
print(y)  # [0. 2. 0. 4.]

PyTorch C++ 扩展

  • PyTorch 提供了一种添加自定义操作的简化方式:
// custom_op.cpp
#include <torch/extension.h>

torch::Tensor custom_gelu(torch::Tensor x) {
    return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_gelu", &custom_gelu, "自定义 GELU 激活函数");
}
# 即时加载并编译
from torch.utils.cpp_extension import load

custom_ops = load(
    name="custom_ops",
    sources=["custom_op.cpp"],
    extra_cflags=["-O3"],
)

x = torch.randn(1000)
y = custom_ops.custom_gelu(x)
  • torch.utils.cpp_extension.load 编译 C++ 代码,创建共享库,并将其作为 Python 模块加载,一步完成。这是在 PyTorch 中实验自定义 C++ 操作的最简单方式。

JAX 自定义调用

  • JAX 使用 XLA 自定义调用。过程更复杂(你需要向 XLA 注册一个 C 函数),但概念相同:写 C/C++,绑定,从 Python 调用。

  • 对大多数 JAX 用户来说,Pallas(见第 05 文件)是更好的选择:它让你用类 Python 语法编写 GPU kernel,由 XLA 编译,无需离开 JAX 生态系统。

全局视角

  • 本文解释了 Python 与硬件之间的那一层。本章其余文件将深入探讨:

    • 第 01 文件:硬件本身(CPU 架构、GPU 架构、内存系统)
    • 第 02-03 文件:CPU 上的 SIMD 编程(ARM NEON、x86 AVX)——在这里你编写使用 CPU 向量单元的 C++
    • 第 04 文件:使用 CUDA 的 GPU 编程——在这里你编写运行在数千个 GPU 核心上的 C++
    • 第 05 文件:Triton、Pallas 和更高级别的 GPU 编程——在这里你编写可编译为 GPU kernel 的 Python
  • 这一进展反映了抽象阶梯:C++ intrinsic(最底层,最高控制权)→ CUDA(GPU 专用)→ Triton/Pallas(Pythonic,编译型)→ JAX/PyTorch(最高层,自动化)。每一层用控制权换取便利性。理解低层级会让你更好地使用高层级。

编程任务(用 g++ 或 clang++ 编译)

  1. 写你的第一个 C++ 程序。分配一个数组,填充它,计算总和,并测量时间。这介绍了编译、数组、指针和计时。

    // task1_basics.cpp
    // 编译:g++ -O3 -o task1 task1_basics.cpp
    // 运行:./task1
    
    #include <iostream>
    #include <chrono>
    #include <vector>
    
    int main() {
        const int N = 10'000'000;  // C++ 允许用 ' 作为数字分隔符
        std::vector<float> data(N);
    
        // 填充数组
        for (int i = 0; i < N; i++) {
            data[i] = static_cast<float>(i) * 0.001f;
        }
    
        // 计算总和
        auto start = std::chrono::high_resolution_clock::now();
        float sum = 0.0f;
        for (int i = 0; i < N; i++) {
            sum += data[i];
        }
        auto end = std::chrono::high_resolution_clock::now();
        double elapsed = std::chrono::duration<double, std::milli>(end - start).count();
    
        std::cout << "Sum: " << sum << std::endl;
        std::cout << "Time: " << elapsed << " ms" << std::endl;
        std::cout << "Elements: " << N << std::endl;
        std::cout << "Throughput: " << (N * sizeof(float)) / elapsed / 1e6 << " GB/s" << std::endl;
    
        return 0;
    }
    

  2. 编写一个 C++ 函数对数组执行 ReLU,然后使用 pybind11 构建 Python 绑定。从 Python 中调用它,并与 NumPy 的速度进行比较。

    // task2_relu.cpp
    // 编译:c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \
    //          task2_relu.cpp -o my_relu$(python3-config --extension-suffix)
    
    #include <pybind11/pybind11.h>
    #include <pybind11/numpy.h>
    namespace py = pybind11;
    
    py::array_t<float> cpp_relu(py::array_t<float> input) {
        auto buf = input.request();
        float* ptr = static_cast<float*>(buf.ptr);
        int n = buf.size;
    
        auto result = py::array_t<float>(n);
        float* out = static_cast<float*>(result.request().ptr);
    
        for (int i = 0; i < n; i++) {
            out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f;
        }
        return result;
    }
    
    PYBIND11_MODULE(my_relu, m) {
        m.def("relu", &cpp_relu, "C++ ReLU");
    }
    
    # test_relu.py — 编译上面的 C++ 模块后运行
    import numpy as np
    import time
    import my_relu  # 编译好的 C++ 模块
    
    x = np.random.randn(10_000_000).astype(np.float32)
    
    # C++ ReLU
    start = time.time()
    for _ in range(100):
        y_cpp = my_relu.relu(x)
    cpp_time = (time.time() - start) / 100
    
    # NumPy ReLU
    start = time.time()
    for _ in range(100):
        y_np = np.maximum(x, 0)
    np_time = (time.time() - start) / 100
    
    print(f"C++ ReLU:   {cpp_time*1000:.2f} ms")
    print(f"NumPy ReLU: {np_time*1000:.2f} ms")
    print(f"Match: {np.allclose(y_cpp, y_np)}")
    

  3. 编写一个 C++ 程序,演示为什么内存布局很重要。比较行优先与列优先的访问模式并测量性能差异。

    // task3_layout.cpp
    // 编译:g++ -O3 -o task3 task3_layout.cpp
    
    #include <iostream>
    #include <chrono>
    #include <vector>
    
    int main() {
        const int N = 4096;
        std::vector<float> matrix(N * N, 1.0f);
    
        // 行优先访问:连续内存地址(缓存友好)
        auto start = std::chrono::high_resolution_clock::now();
        float sum_row = 0.0f;
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                sum_row += matrix[i * N + j];  // stride-1 访问
            }
        }
        auto end = std::chrono::high_resolution_clock::now();
        double row_ms = std::chrono::duration<double, std::milli>(end - start).count();
    
        // 列优先访问:stride-N 访问(缓存不友好)
        start = std::chrono::high_resolution_clock::now();
        float sum_col = 0.0f;
        for (int j = 0; j < N; j++) {
            for (int i = 0; i < N; i++) {
                sum_col += matrix[i * N + j];  // stride-N 访问(缓存 miss!)
            }
        }
        end = std::chrono::high_resolution_clock::now();
        double col_ms = std::chrono::duration<double, std::milli>(end - start).count();
    
        std::cout << "行优先(缓存友好): " << row_ms << " ms" << std::endl;
        std::cout << "列优先(缓存不友好): " << col_ms << " ms" << std::endl;
        std::cout << "慢了: " << col_ms / row_ms << "x" << std::endl;
        std::cout << "(两个总和: " << sum_row << ", " << sum_col << ")" << std::endl;
    
        return 0;
    }