Skip to content

Triton 和 TPUs

CUDA C 功能强大但冗长。 Triton 允许您在 Python 中写入 GPU kernels。 TPUs 提供了具有不同权衡的 GPU 替代方案。本文件涵盖 Triton kernel 编程、Flash Attention 作为案例研究、TPU 架构和 JAX/Pallas 以及如何选择正确的工具。对于 Vulkan 和跨平台 GPU 计算,请参阅文件 07.

  • 上一个文件教了 GPU 在 CUDA C 中编程。该文件攀登了抽象阶梯:在 Python 中,Triton 可以为您提供 CUDA 80% 的性能,而只需 20% 的努力。 TPUs 和 Vulkan 为特定用例提供替代硬件目标。

Triton:Python 中的 GPU kernel

  • Triton (OpenAI) 是一种基于 Python 的语言,用于编写 GPU kernels。您不是对单个 threads (CUDA) 进行inference,而是对数据的 blocks 进行inference。 Triton 的编译器自动处理 thread 映射、内存merge、shared memory 管理和许多优化。

  • 为什么 Triton 很重要:CUDA C 需要深入了解 warp 调度、shared memory 存储体conflict、register 压力和merge模式。 Triton 抽象了其中的大部分内容,使得了解 Python 但不了解系统编程的 ML 研究人员可以进行 GPU kernel 开发。

您的第一个 Triton kernel

import triton
import triton.language as tl
import torch

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,  # compile-time constant
):
    # Each program instance processes one block of BLOCK_SIZE elements
    pid = tl.program_id(axis=0)  # which block am I?
    block_start = pid * BLOCK_SIZE

    # Offsets for this block
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Mask to handle the case where n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements

    # Load data (masked: out-of-bounds reads return 0)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute
    output = x + y

    # Store result
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    # Launch: one program per block
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output


# Usage
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
  • 与 CUDA 的主要区别
    • 没有明确的 thread 管理。你认为是blocks(程序),而不是threads。
    • tl.arange(0, BLOCK_SIZE) 为整个 block 创建偏移量 vector。对这个vector的所有操作都隐式地是vectorised。
    • mask 处理边界条件(如 AVX-512 掩码 registers,文件 03)。不需要标量清理循环。
    • tl.loadtl.store 自动处理merge访问。
    • @triton.jit在第一次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的kernel。

Triton Softmax kernel

  • Softmax 是一个很棒的 Triton 示例,因为它需要对数据进行多次传递(最大值、减法、指数、求和、除法),并且可以在传递之间将数据保留在 SRAM (shared memory) 中:
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one row
    row_idx = tl.program_id(0)
    row_start = input_ptr + row_idx * input_row_stride

    # Load the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))

    # Softmax: max for numerical stability, then exp, then normalise
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Store result
    output_start = output_ptr + row_idx * output_row_stride
    tl.store(output_start + col_offsets, softmax_output, mask=mask)
  • 在 PyTorch 中,F.softmax(x, dim=-1) 启动 3 个独立的 kernels(最大值、求和、除法),每个对 global memory 进行读取和写入。 Triton 版本在一个 kernel 中完成所有操作,将数据保存在 registers/SRAM 中。这种 kernel 融合 就是为什么自定义 Triton kernels 可以比 PyTorch 的内置操作快 2-4 倍。

Triton 自整定

  • Triton 支持自动调整:尝试多种配置并选择最快的:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
    ],
    key=['M', 'N', 'K'],  # re-tune when these change
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
    ...
  • Triton 在实际硬件上对每个 configuration 进行基准测试并选择最快的。最佳图block大小取决于 GPU 架构、矩阵尺寸和内存布局 - 自动调整无需手动实验即可找到它们。

Triton 与 CUDA:何时使用各自

Triton CUDA C
语言 Python C/C++
抽象 block级 thread级
开发速度 快速(每个 kernel 10-50 行) 慢速(100-500 行)
绩效上限 手动调校 CUDA 的 ~80-95% 100%(全硬件控制)
shared memory 自动的 手动的
聚结 自动的 手动的
扭曲级别基元 有限的 完整(随机播放、投票等)
硬件支持 仅NVIDIA(AMD实验) 仅NVIDIA
  • 使用 Triton 用于:融合 kernels、自定义 attention 模式、激活函数、大多数 ML 研究 kernel 需要。
  • 使用 CUDA C 用于:最大性能(最后 5-20%)、warp 级原语、复杂的数据相关并行性(当 Triton 无法表达您的模式时)。

案例研究:Flash Attention

  • Flash Attention(Dao 等人,2022)是最近 ML 中最具影响力的自定义 kernel。它使用 \(O(n)\) 内存而不是 \(O(n^2)\) 来计算 attention,从而实现更长的序列。

  • 问题:标准 attention 计算 \(\text{softmax}(QK^T / \sqrt{d}) \cdot V\)\(QK^T\) 矩阵是 \(n \times n\),其中 \(n\) 是序列长度。对于 \(n = 128K\),该矩阵为 \(128K \times 128K \times 4\) 字节 = 64 GB。它不适合 GPU 内存。

  • 见解:您不需要具体化完整的 \(n \times n\) 矩阵。在 tiles 中计算 attention:加载 \(Q\) 的 block、\(K\) 的 block,计算它们的部分 attention 分数,累加,然后移动到下一个 block。 \(n \times n\) 矩阵从未完全具体化——SRAM 中一次仅存在一个图block。

  • 在线 softmax:棘手的部分是 softmax,它需要知道整行的最大值(为了数值稳定性)。 Flash Attention 使用在线 softmax 技巧:维持运行的最大值,并在找到新的最大值时重新调整先前计算的值。这使得 softmax 可以增量计算,一次一个图block。

  • 算法:

For each block of Q rows:
    For each block of K columns:
        1. Load Q_block from HBM to SRAM
        2. Load K_block from HBM to SRAM
        3. Compute S_block = Q_block @ K_block.T (in SRAM)
        4. Update running max, rescale previous results
        5. Compute exp(S_block - running_max)
        6. Update running sum and output accumulator
    Load V_block and compute final output
    Write output block back to HBM
  • 为什么快:内循环完全在SRAM(shared memory)中运行。全局存储器(HBM)仅被访问以加载Q、K、V的blocks并写入最终输出。数据重用系数与 SRAM 大小成正比,访问速度比 HBM 快约 100 倍。

  • Flash Attention 在 Triton 和 CUDA C 中均实现。 CUDA 版本速度更快(效率提高约 10%),但 Triton 版本的可读性和可修改性要高得多,这对于新 attention 变体的研究很重要。

TPU架构

  • TPUs(张量处理单元)是 Google 的定制 ML 加速器。他们采用了与 GPU 完全不同的方法:

  • 脉动阵列:TPU 的核心计算单元是一个矩阵乘法单元 (MXU),这是一个 128×128 或 256×256 脉动阵列,通过使数据流经乘法累加单元的 grid 来计算矩阵乘法。数据从边缘进入并通过数组传播,每个单元执行一次乘加并将结果传递到下一个单元。

  • 与 GPU(调度数千个独立的 threads)不同,脉动阵列是单个确定性数据流。没有thread调度,没有warp发散,没有branch预测。这种简单性使得 MXU 在矩阵乘法方面极其节能。

  • HBM:TPUs 使用与 GPU 相同的高带宽内存。 TPU v5e 每个芯片具有 16 GB HBM2e; TPU v5p 具有 95 GB HBM2e。

  • ICI(芯片间互连):TPU Pod 通过定制高速网络连接数百个 TPUs。 JAX 原生支持 TPU Pod 之间的数据并行性和 model 并行性(第 6 章)。

  • BFloat16:TPUs 是第一个使用 bfloat16 的人(第 13 章,文件 02)。 BF16 具有与 float32 相同的指数范围(防止 training 期间溢出),但尾数精度较低。这种权衡对于 ML 来说是理想的选择,因为梯度值的范围很广,但不需要 23 位精度。

编程 TPUs:JAX 和 Pallas

  • TPUs通过JAXXLA进行编程。您编写 Python/JAX 代码,jax.jit 将其编译为 XLA HLO,XLA 将 HLO 编译为 TPU 特定指令。没有CUDA,没有C++。
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

# This runs on CPU, GPU, or TPU depending on the device
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
  • Pallas 是 JAX 的 kernel 创作的 API — Triton 的 JAX 等效项。它允许您编写 XLA 为 GPU 或 TPU 编译的低级 kernels:
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add_pallas(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)
  • Pallas 比 Triton 更新且不太成熟,但它是为 TPUs 编写自定义 kernels 的唯一方法(因为 TPUs 不支持 CUDA)。

GPU 与 TPU 对比

GPU (NVIDIA) TPU(谷歌)
可用性 任何本地云 仅限谷歌云
编程 CUDA C、Triton、PyTorch JAX/XLA、Pallas
灵活性 通用计算 针对矩阵密集型机器学习进行了优化
峰值矩阵相乘 FLOPS 非常高(张量核心) 非常高 (MXU)
非矩阵乘运算 好的 较慢(通过 vector 单元,而不是 MXU 路由)
多芯片缩放 NVLink(8 个 GPU)、InfiniBand ICI(数千个TPUs,集成更紧密)
成本效率 竞争的 大型training通常更便宜
生态系统 最大(PyTorch、TensorFlow、JAX) JAX-专注
  • 使用 GPU 用于:大多数 ML 工作负载、基于 PyTorch 的研究、inference 服务、具有大量非 matmul 计算的工作负载。
  • 将 TPUs 用于:大规模 JAX training(数千个芯片)、Google Cloud 上对成本敏感的 training、以矩阵乘法为主的工作负载。

选择正确的工具

工作量 最佳工具 为什么
ML training(PyTorch) NVIDIA GPU + CUDA/Triton 最大的生态系统,最好的工具
ML training(JAX,大型) TPU 或 NVIDIA GPU TPU 用于 Google 规模的成本,GPU 用于灵活性
定制熔断器 kernels Triton (Python) 或 CUDA C Triton 实现开发速度,CUDA 实现峰值性能
JAX 定制 kernels Pallas TPU 的唯一选项,也适用于 GPU
跨平台inference Vulkan(文件 07)或 ONNX 运行时 可在任何 GPU 供应商上运行
移动/边缘 inference 金属(苹果)、Vulkan(安卓)、NNAPI 特定于平台的加速器
浏览器 inference WebGPU(文件07) 浏览器中唯一的选项
CPU-仅inference ONNX 运行时 + AVX/NEON 不需要 GPU,使用 SIMD(文件 02-03)
新颖的硬件 特定于供应商的 SDK 每个加速器都有自己的工具链

编码任务(使用带有 GPU 运行时的 CoLab)

  1. 编写并运行 Triton kernel 以进行 vector 加法。将其性能与 PyTorch 的内置附加功能进行比较。

    import triton
    import triton.language as tl
    import torch
    import time
    
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)
    
    n = 10_000_000
    x = torch.randn(n, device='cuda')
    y = torch.randn(n, device='cuda')
    
    # Triton
    out_triton = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
    
    # PyTorch
    out_torch = x + y
    
    # Verify correctness
    assert torch.allclose(out_triton, out_torch, atol=1e-5)
    
    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
    torch.cuda.synchronize()
    triton_time = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        out_torch = x + y
    torch.cuda.synchronize()
    torch_time = (time.time() - start) / 1000
    
    print(f"Triton:  {triton_time*1000:.3f} ms")
    print(f"PyTorch: {torch_time*1000:.3f} ms")
    print(f"Ratio:   {torch_time/triton_time:.2f}x")
    

  2. 编写一个 Triton 融合的 kernel,它在一次传递中执行乘法+加法+ReLU。与三个单独的 PyTorch 操作进行比较。

    import triton
    import triton.language as tl
    import torch
    import time
    
    @triton.jit
    def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        w = tl.load(w_ptr + offs, mask=mask)
        b = tl.load(b_ptr + offs, mask=mask)
        result = tl.maximum(x * w + b, 0.0)  # fused: mul + add + relu
        tl.store(out_ptr + offs, result, mask=mask)
    
    n = 10_000_000
    x = torch.randn(n, device='cuda')
    w = torch.randn(n, device='cuda')
    b = torch.randn(n, device='cuda')
    
    # Fused (Triton)
    out_fused = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
    
    # Unfused (PyTorch)
    out_unfused = torch.relu(x * w + b)
    
    assert torch.allclose(out_fused, out_unfused, atol=1e-5)
    
    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
    torch.cuda.synchronize()
    fused_time = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        out_unfused = torch.relu(x * w + b)
    torch.cuda.synchronize()
    unfused_time = (time.time() - start) / 1000
    
    print(f"Fused (Triton):    {fused_time*1000:.3f} ms")
    print(f"Unfused (PyTorch): {unfused_time*1000:.3f} ms")
    print(f"Speedup:           {unfused_time/fused_time:.2f}x")
    

  3. 测量 JAX 的 XLA 编译器如何自动融合操作。比较有和没有 jit 的操作链。

    import jax
    import jax.numpy as jnp
    import time
    
    def chain_ops(x):
        x = x * 2.0
        x = x + 1.0
        x = jnp.maximum(x, 0.0)  # ReLU
        x = x / jnp.sum(x)
        return x
    
    chain_jit = jax.jit(chain_ops)
    x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))
    
    # Warm up
    _ = chain_jit(x)
    jax.block_until_ready(_)
    
    # Eager (each op is a separate kernel launch)
    start = time.time()
    for _ in range(100):
        y = chain_ops(x)
    jax.block_until_ready(y)
    eager_time = (time.time() - start) / 100
    
    # JIT (XLA fuses operations)
    start = time.time()
    for _ in range(100):
        y = chain_jit(x)
    jax.block_until_ready(y)
    jit_time = (time.time() - start) / 100
    
    print(f"Eager: {eager_time*1000:.2f} ms")
    print(f"JIT:   {jit_time*1000:.2f} ms")
    print(f"Speedup: {eager_time/jit_time:.1f}x (XLA fuses the 4 operations into 1 kernel)")