Triton 和 TPUs¶
CUDA C 功能强大但冗长。 Triton 允许您在 Python 中写入 GPU kernels。 TPUs 提供了具有不同权衡的 GPU 替代方案。本文件涵盖 Triton kernel 编程、Flash Attention 作为案例研究、TPU 架构和 JAX/Pallas 以及如何选择正确的工具。对于 Vulkan 和跨平台 GPU 计算,请参阅文件 07.
- 上一个文件教了 GPU 在 CUDA C 中编程。该文件攀登了抽象阶梯:在 Python 中,Triton 可以为您提供 CUDA 80% 的性能,而只需 20% 的努力。 TPUs 和 Vulkan 为特定用例提供替代硬件目标。
Triton:Python 中的 GPU kernel¶
-
Triton (OpenAI) 是一种基于 Python 的语言,用于编写 GPU kernels。您不是对单个 threads (CUDA) 进行inference,而是对数据的 blocks 进行inference。 Triton 的编译器自动处理 thread 映射、内存merge、shared memory 管理和许多优化。
-
为什么 Triton 很重要:CUDA C 需要深入了解 warp 调度、shared memory 存储体conflict、register 压力和merge模式。 Triton 抽象了其中的大部分内容,使得了解 Python 但不了解系统编程的 ML 研究人员可以进行 GPU kernel 开发。
您的第一个 Triton kernel¶
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr, # compile-time constant
):
# Each program instance processes one block of BLOCK_SIZE elements
pid = tl.program_id(axis=0) # which block am I?
block_start = pid * BLOCK_SIZE
# Offsets for this block
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Mask to handle the case where n_elements is not a multiple of BLOCK_SIZE
mask = offsets < n_elements
# Load data (masked: out-of-bounds reads return 0)
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# Compute
output = x + y
# Store result
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
# Launch: one program per block
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
# Usage
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
- 与 CUDA 的主要区别:
- 没有明确的 thread 管理。你认为是blocks(程序),而不是threads。
tl.arange(0, BLOCK_SIZE)为整个 block 创建偏移量 vector。对这个vector的所有操作都隐式地是vectorised。mask处理边界条件(如 AVX-512 掩码 registers,文件 03)。不需要标量清理循环。tl.load和tl.store自动处理merge访问。@triton.jit在第一次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的kernel。
Triton Softmax kernel¶
- Softmax 是一个很棒的 Triton 示例,因为它需要对数据进行多次传递(最大值、减法、指数、求和、除法),并且可以在传递之间将数据保留在 SRAM (shared memory) 中:
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr,
):
# Each program handles one row
row_idx = tl.program_id(0)
row_start = input_ptr + row_idx * input_row_stride
# Load the row
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))
# Softmax: max for numerical stability, then exp, then normalise
row_max = tl.max(row, axis=0)
numerator = tl.exp(row - row_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Store result
output_start = output_ptr + row_idx * output_row_stride
tl.store(output_start + col_offsets, softmax_output, mask=mask)
- 在 PyTorch 中,
F.softmax(x, dim=-1)启动 3 个独立的 kernels(最大值、求和、除法),每个对 global memory 进行读取和写入。 Triton 版本在一个 kernel 中完成所有操作,将数据保存在 registers/SRAM 中。这种 kernel 融合 就是为什么自定义 Triton kernels 可以比 PyTorch 的内置操作快 2-4 倍。
Triton 自整定¶
- Triton 支持自动调整:尝试多种配置并选择最快的:
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
],
key=['M', 'N', 'K'], # re-tune when these change
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
...
- Triton 在实际硬件上对每个 configuration 进行基准测试并选择最快的。最佳图block大小取决于 GPU 架构、矩阵尺寸和内存布局 - 自动调整无需手动实验即可找到它们。
Triton 与 CUDA:何时使用各自¶
| Triton | CUDA C | |
|---|---|---|
| 语言 | Python | C/C++ |
| 抽象 | block级 | thread级 |
| 开发速度 | 快速(每个 kernel 10-50 行) | 慢速(100-500 行) |
| 绩效上限 | 手动调校 CUDA 的 ~80-95% | 100%(全硬件控制) |
| shared memory | 自动的 | 手动的 |
| 聚结 | 自动的 | 手动的 |
| 扭曲级别基元 | 有限的 | 完整(随机播放、投票等) |
| 硬件支持 | 仅NVIDIA(AMD实验) | 仅NVIDIA |
- 使用 Triton 用于:融合 kernels、自定义 attention 模式、激活函数、大多数 ML 研究 kernel 需要。
- 使用 CUDA C 用于:最大性能(最后 5-20%)、warp 级原语、复杂的数据相关并行性(当 Triton 无法表达您的模式时)。
案例研究:Flash Attention¶
-
Flash Attention(Dao 等人,2022)是最近 ML 中最具影响力的自定义 kernel。它使用 \(O(n)\) 内存而不是 \(O(n^2)\) 来计算 attention,从而实现更长的序列。
-
问题:标准 attention 计算 \(\text{softmax}(QK^T / \sqrt{d}) \cdot V\)。 \(QK^T\) 矩阵是 \(n \times n\),其中 \(n\) 是序列长度。对于 \(n = 128K\),该矩阵为 \(128K \times 128K \times 4\) 字节 = 64 GB。它不适合 GPU 内存。
-
见解:您不需要具体化完整的 \(n \times n\) 矩阵。在 tiles 中计算 attention:加载 \(Q\) 的 block、\(K\) 的 block,计算它们的部分 attention 分数,累加,然后移动到下一个 block。 \(n \times n\) 矩阵从未完全具体化——SRAM 中一次仅存在一个图block。
-
在线 softmax:棘手的部分是 softmax,它需要知道整行的最大值(为了数值稳定性)。 Flash Attention 使用在线 softmax 技巧:维持运行的最大值,并在找到新的最大值时重新调整先前计算的值。这使得 softmax 可以增量计算,一次一个图block。
-
算法:
For each block of Q rows:
For each block of K columns:
1. Load Q_block from HBM to SRAM
2. Load K_block from HBM to SRAM
3. Compute S_block = Q_block @ K_block.T (in SRAM)
4. Update running max, rescale previous results
5. Compute exp(S_block - running_max)
6. Update running sum and output accumulator
Load V_block and compute final output
Write output block back to HBM
-
为什么快:内循环完全在SRAM(shared memory)中运行。全局存储器(HBM)仅被访问以加载Q、K、V的blocks并写入最终输出。数据重用系数与 SRAM 大小成正比,访问速度比 HBM 快约 100 倍。
-
Flash Attention 在 Triton 和 CUDA C 中均实现。 CUDA 版本速度更快(效率提高约 10%),但 Triton 版本的可读性和可修改性要高得多,这对于新 attention 变体的研究很重要。
TPU架构¶
-
TPUs(张量处理单元)是 Google 的定制 ML 加速器。他们采用了与 GPU 完全不同的方法:
-
脉动阵列:TPU 的核心计算单元是一个矩阵乘法单元 (MXU),这是一个 128×128 或 256×256 脉动阵列,通过使数据流经乘法累加单元的 grid 来计算矩阵乘法。数据从边缘进入并通过数组传播,每个单元执行一次乘加并将结果传递到下一个单元。
-
与 GPU(调度数千个独立的 threads)不同,脉动阵列是单个确定性数据流。没有thread调度,没有warp发散,没有branch预测。这种简单性使得 MXU 在矩阵乘法方面极其节能。
-
HBM:TPUs 使用与 GPU 相同的高带宽内存。 TPU v5e 每个芯片具有 16 GB HBM2e; TPU v5p 具有 95 GB HBM2e。
-
ICI(芯片间互连):TPU Pod 通过定制高速网络连接数百个 TPUs。 JAX 原生支持 TPU Pod 之间的数据并行性和 model 并行性(第 6 章)。
-
BFloat16:TPUs 是第一个使用 bfloat16 的人(第 13 章,文件 02)。 BF16 具有与 float32 相同的指数范围(防止 training 期间溢出),但尾数精度较低。这种权衡对于 ML 来说是理想的选择,因为梯度值的范围很广,但不需要 23 位精度。
编程 TPUs:JAX 和 Pallas¶
- TPUs通过JAX和XLA进行编程。您编写 Python/JAX 代码,
jax.jit将其编译为 XLA HLO,XLA 将 HLO 编译为 TPU 特定指令。没有CUDA,没有C++。
import jax
import jax.numpy as jnp
@jax.jit
def matmul(a, b):
return jnp.dot(a, b)
# This runs on CPU, GPU, or TPU depending on the device
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
- Pallas 是 JAX 的 kernel 创作的 API — Triton 的 JAX 等效项。它允许您编写 XLA 为 GPU 或 TPU 编译的低级 kernels:
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
def add_pallas(x, y):
return pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(x.shape[0] // 128,),
in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
pl.BlockSpec((128,), lambda i: (i,))],
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
)(x, y)
- Pallas 比 Triton 更新且不太成熟,但它是为 TPUs 编写自定义 kernels 的唯一方法(因为 TPUs 不支持 CUDA)。
GPU 与 TPU 对比¶
| GPU (NVIDIA) | TPU(谷歌) | |
|---|---|---|
| 可用性 | 任何本地云 | 仅限谷歌云 |
| 编程 | CUDA C、Triton、PyTorch | JAX/XLA、Pallas |
| 灵活性 | 通用计算 | 针对矩阵密集型机器学习进行了优化 |
| 峰值矩阵相乘 FLOPS | 非常高(张量核心) | 非常高 (MXU) |
| 非矩阵乘运算 | 好的 | 较慢(通过 vector 单元,而不是 MXU 路由) |
| 多芯片缩放 | NVLink(8 个 GPU)、InfiniBand | ICI(数千个TPUs,集成更紧密) |
| 成本效率 | 竞争的 | 大型training通常更便宜 |
| 生态系统 | 最大(PyTorch、TensorFlow、JAX) | JAX-专注 |
- 使用 GPU 用于:大多数 ML 工作负载、基于 PyTorch 的研究、inference 服务、具有大量非 matmul 计算的工作负载。
- 将 TPUs 用于:大规模 JAX training(数千个芯片)、Google Cloud 上对成本敏感的 training、以矩阵乘法为主的工作负载。
选择正确的工具¶
| 工作量 | 最佳工具 | 为什么 |
|---|---|---|
| ML training(PyTorch) | NVIDIA GPU + CUDA/Triton | 最大的生态系统,最好的工具 |
| ML training(JAX,大型) | TPU 或 NVIDIA GPU | TPU 用于 Google 规模的成本,GPU 用于灵活性 |
| 定制熔断器 kernels | Triton (Python) 或 CUDA C | Triton 实现开发速度,CUDA 实现峰值性能 |
| JAX 定制 kernels | Pallas | TPU 的唯一选项,也适用于 GPU |
| 跨平台inference | Vulkan(文件 07)或 ONNX 运行时 | 可在任何 GPU 供应商上运行 |
| 移动/边缘 inference | 金属(苹果)、Vulkan(安卓)、NNAPI | 特定于平台的加速器 |
| 浏览器 inference | WebGPU(文件07) | 浏览器中唯一的选项 |
| CPU-仅inference | ONNX 运行时 + AVX/NEON | 不需要 GPU,使用 SIMD(文件 02-03) |
| 新颖的硬件 | 特定于供应商的 SDK | 每个加速器都有自己的工具链 |
编码任务(使用带有 GPU 运行时的 CoLab)¶
-
编写并运行 Triton kernel 以进行 vector 加法。将其性能与 PyTorch 的内置附加功能进行比较。
import triton import triton.language as tl import torch import time @triton.jit def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) y = tl.load(y_ptr + offs, mask=mask) tl.store(out_ptr + offs, x + y, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') y = torch.randn(n, device='cuda') # Triton out_triton = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) add_kernel[grid](x, y, out_triton, n, BLOCK=1024) # PyTorch out_torch = x + y # Verify correctness assert torch.allclose(out_triton, out_torch, atol=1e-5) # Benchmark torch.cuda.synchronize() start = time.time() for _ in range(1000): add_kernel[grid](x, y, out_triton, n, BLOCK=1024) torch.cuda.synchronize() triton_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_torch = x + y torch.cuda.synchronize() torch_time = (time.time() - start) / 1000 print(f"Triton: {triton_time*1000:.3f} ms") print(f"PyTorch: {torch_time*1000:.3f} ms") print(f"Ratio: {torch_time/triton_time:.2f}x") -
编写一个 Triton 融合的 kernel,它在一次传递中执行乘法+加法+ReLU。与三个单独的 PyTorch 操作进行比较。
import triton import triton.language as tl import torch import time @triton.jit def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) w = tl.load(w_ptr + offs, mask=mask) b = tl.load(b_ptr + offs, mask=mask) result = tl.maximum(x * w + b, 0.0) # fused: mul + add + relu tl.store(out_ptr + offs, result, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') w = torch.randn(n, device='cuda') b = torch.randn(n, device='cuda') # Fused (Triton) out_fused = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) # Unfused (PyTorch) out_unfused = torch.relu(x * w + b) assert torch.allclose(out_fused, out_unfused, atol=1e-5) # Benchmark torch.cuda.synchronize() start = time.time() for _ in range(1000): fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) torch.cuda.synchronize() fused_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_unfused = torch.relu(x * w + b) torch.cuda.synchronize() unfused_time = (time.time() - start) / 1000 print(f"Fused (Triton): {fused_time*1000:.3f} ms") print(f"Unfused (PyTorch): {unfused_time*1000:.3f} ms") print(f"Speedup: {unfused_time/fused_time:.2f}x") -
测量 JAX 的 XLA 编译器如何自动融合操作。比较有和没有 jit 的操作链。
import jax import jax.numpy as jnp import time def chain_ops(x): x = x * 2.0 x = x + 1.0 x = jnp.maximum(x, 0.0) # ReLU x = x / jnp.sum(x) return x chain_jit = jax.jit(chain_ops) x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000)) # Warm up _ = chain_jit(x) jax.block_until_ready(_) # Eager (each op is a separate kernel launch) start = time.time() for _ in range(100): y = chain_ops(x) jax.block_until_ready(y) eager_time = (time.time() - start) / 100 # JIT (XLA fuses operations) start = time.time() for _ in range(100): y = chain_jit(x) jax.block_until_ready(y) jit_time = (time.time() - start) / 100 print(f"Eager: {eager_time*1000:.2f} ms") print(f"JIT: {jit_time*1000:.2f} ms") print(f"Speedup: {eager_time/jit_time:.1f}x (XLA fuses the 4 operations into 1 kernel)")