代码库设计和模式¶
良好的 codebase 设计是将研究原型与生产软件区分开来的。该文件涵盖项目结构、clean code 原理、与 ML 相关的 design patterns、configuration management、logging、API 设计和包装
-
大多数 ML 代码都是从 Jupyter 笔记本开始的。笔记本不断增长、被复制、修改、共享,最终变成一堆无法维护的全局变量、死亡单元和幻数。 代码库设计是组织代码的学科,以便随着项目的发展它仍然是可理解和可修改的。
-
这并不是为了遵守规则而遵守规则。它是为了减少“我想改变 X”和“X 已改变并且正在工作”之间的时间。在精心设计的 codebase 中,这个时间只有几分钟。在设计不佳的情况下,这是通过无证意大利面条进行考古的日子。
项目结构¶
- 一致的项目布局可以让任何人(包括未来的您)立即导航 codebase。
my_project/
├── src/my_project/ # source code (importable package)
│ ├── __init__.py
│ ├── data/ # data loading and preprocessing
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── transforms.py
│ ├── models/ # model architectures
│ │ ├── __init__.py
│ │ ├── transformer.py
│ │ └── layers.py
│ ├── training/ # training loops, optimisers
│ │ ├── __init__.py
│ │ ├── trainer.py
│ │ └── losses.py
│ └── utils/ # shared utilities
│ ├── __init__.py
│ └── logging.py
├── configs/ # configuration files
│ ├── base.yaml
│ └── experiment_1.yaml
├── scripts/ # entry points (train, evaluate, serve)
│ ├── train.py
│ ├── evaluate.py
│ └── serve.py
├── tests/ # test files (mirrors src/ structure)
│ ├── test_dataset.py
│ ├── test_model.py
│ └── test_trainer.py
├── notebooks/ # exploration only (not production code)
├── pyproject.toml # project metadata and dependencies
├── README.md
├── .gitignore
└── Dockerfile
-
src/布局:将源代码放在src/my_project/下可以防止从当前目录意外导入(这会掩盖生产中出现的导入错误)。与pip install -e .一起安装进行开发。 -
Monorepo 与 multi-repo:monorepo 将所有相关项目保留在一个 repository 中(更轻松的跨项目更改,共享 CI)。 multi-repo 为每个项目提供了自己的 repository(更清晰的边界,独立的version control)。大多数 ML 团队都是从 monorepo 开始,然后根据需要进行拆分。
-
scripts与 library:将入口点(
train.py、evaluate.py)保留在scripts/中。将可重用逻辑保留在src/中。 training script 应该约为 50 行:解析 config、构建 dataset、构建 model、构建training器、training。所有的复杂性都集中在 library 中。
干净代码原则¶
- 命名:您可以做的最有影响力的事情。名为
x的变量需要您阅读周围的代码才能理解它。名为learning_rate的变量是自记录的。
# BAD
def proc(d, n, lr):
for i in range(n):
for k, v in d.items():
v -= lr * g[k]
# GOOD
def update_parameters(parameters, num_steps, learning_rate):
for step in range(num_steps):
for name, param in parameters.items():
param -= learning_rate * gradients[name]
-
Single Responsibility Principle:每个函数/类只做一件事。名为
load_data_and_train_model的函数正在做两件事,应该分开。这使得每个部分都可以独立测试、可重用且易于理解。 -
DRY(不要重复自己) - 但不要过早。如果将代码复制粘贴三次,请将其提取到函数中。但不要为只使用过一次的代码创建抽象。过早的抽象比重复更糟糕:它增加了复杂性,但没有经过证实的好处。
# Premature abstraction (one use case, over-engineered)
class AbstractDataTransformPipelineFactory:
...
# Just right (direct, clear, used in three places)
def normalise_image(image, mean, std):
return (image - mean) / std
- 幻数:切勿使用未解释的文字值。
# BAD
if len(batch) > 32:
split_batch(batch, 32)
# GOOD
MAX_BATCH_SIZE = 32
if len(batch) > MAX_BATCH_SIZE:
split_batch(batch, MAX_BATCH_SIZE)
- 函数应该简短:如果一个函数无法在一个屏幕上显示(~30 行),那么它可能做得太多了。将逻辑block提取到具有描述性名称的辅助函数中。函数体读起来就像一个高级摘要。
机器学习的设计模式¶
-
设计模式是常见问题的可重用解决方案。这些是与 ML 代码库最相关的代码:
-
工厂模式:创建对象而不指定确切的类。当您的 config 为
model: "transformer"并且您需要实例化正确的类时很有用:
MODEL_REGISTRY = {
"transformer": TransformerModel,
"cnn": CNNModel,
"mlp": MLPModel,
}
def build_model(config):
model_cls = MODEL_REGISTRY[config["model"]]
return model_cls(**config["model_params"])
-
这将 training script 与特定的 model 实现解耦。添加新的model意味着在注册表中添加一行,而不是修改training循环。
-
策略模式:在运行时交换算法。对于损失、optimizer、调度器有用:
LOSS_FUNCTIONS = {
"mse": nn.MSELoss,
"cross_entropy": nn.CrossEntropyLoss,
"focal": FocalLoss,
}
loss_fn = LOSS_FUNCTIONS[config["loss"]]()
- 观察者模式(回调/挂钩):让模block对事件做出反应,而无需紧密耦合。training框架(PyTorch Lightning、Keras)广泛使用回调:
class EarlyStopping:
def __init__(self, patience=5):
self.patience = patience
self.best_loss = float('inf')
self.counter = 0
def on_epoch_end(self, epoch, val_loss):
if val_loss < self.best_loss:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
return "stop"
- 依赖注入:将依赖项传递到函数/类中,而不是在内部创建它们。这使得测试变得容易(注入模拟)并且 configuration 灵活:
# BAD: hard-coded dependency
class Trainer:
def __init__(self):
self.logger = WandbLogger() # cannot test without W&B
# GOOD: injected dependency
class Trainer:
def __init__(self, logger):
self.logger = logger # can inject any logger, including a mock
配置管理¶
-
硬编码超参数、文件路径和 model 设置使实验无法重现且修改非常痛苦。 将 configuration 外部化到文件中。
-
YAML 是 ML 配置最常见的格式:
# configs/experiment_1.yaml
model:
name: transformer
d_model: 512
n_heads: 8
n_layers: 6
training:
batch_size: 64
learning_rate: 3e-4
max_epochs: 100
early_stopping_patience: 10
data:
train_path: /data/train.parquet
val_path: /data/val.parquet
max_seq_length: 512
-
Hydra (Facebook) 是一个 configuration 框架,支持组合(merge 基础 config 以及特定于实验的覆盖)、命令行覆盖 (
python train.py training.lr=1e-3) 和多次运行(扫描超参数)。 -
argparse 对于 scripts 来说更简单,有几个参数:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--config", type=str, default="configs/base.yaml")
args = parser.parse_args()
- 最佳实践:拥有一个包含所有默认值的基本 config,以及仅覆盖更改内容的每个实验配置。跟踪每个实验的 config 及其结果。
日志记录和可观察性¶
print语句用于调试。 日志记录用于生产:
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.debug("Batch loaded: %d samples", len(batch)) # noisy, for debugging
logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr) # normal operation
logger.warning("GPU memory >90%%, consider reducing batch size")
logger.error("Failed to load checkpoint: %s", path) # recoverable error
logger.critical("CUDA out of memory, aborting") # fatal
-
为什么不打印:logging 支持级别(过滤掉生产中的调试消息)、formatting(时间戳、模block名称)和处理程序(写入文件、发送到监控系统),而无需更改 logging 调用。
-
结构化 logging 输出机器可解析的格式 (JSON) 以及人类可读的消息。这可以实现对特定字段的搜索和警报:
API设计¶
-
如果您的 model 将由其他服务(Web 应用程序、移动应用程序、另一个 ML 管道)使用,则它需要 API(应用程序编程接口)。
-
REST API 使用 HTTP 方法:
GET读取、POST创建/预测、PUT更新、DELETE删除。端点遵循基于资源的命名:
POST /api/v1/predict # send input, get prediction
GET /api/v1/models # list available models
GET /api/v1/models/{id} # get model details
POST /api/v1/models/{id}/predict # predict with a specific model
- FastAPI 是机器学习服务的首选 Python 框架:
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class PredictRequest(BaseModel):
text: str
class PredictResponse(BaseModel):
label: str
confidence: float
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
result = model.predict(request.text)
return PredictResponse(label=result.label, confidence=result.score)
-
FastAPI 自动生成 API 文档(
/docs的 Swagger UI),使用 Pydantic models 验证输入/输出,并支持异步以实现高吞吐量。 -
对于内部服务到服务的通信,gRPC 比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小、更快)并支持流式传输。由 TensorFlow Serving、Triton inference服务器和许多微服务架构使用。
包装和分销¶
- 使您的代码可作为包安装,以便其他人(以及您自己的 scripts)干净地导入它:
# pyproject.toml
[project]
name = "my-ml-project"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"jax>=0.4",
"pydantic>=2.0",
]
[project.optional-dependencies]
dev = ["pytest", "ruff", "mypy"]
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.backends._legacy:_Backend"
-
可编辑安装 (
-e):对源代码的更改会立即反映出来,无需重新安装。开发过程中必不可少。 -
固定依赖项:具有精确版本的
requirements.txt(torch==2.2.1,而不是torch>=2.0)可确保可重复性。使用pip freeze > requirements.txt捕获您当前的环境。对于更复杂的依赖关系管理,请使用uv、poetry或pip-tools。
使用 AI 编码代理¶
-
AI 编码代理(Claude Code、GitHub Copilot、Cursor 等)现已成为专业工程工作流程的一部分。使用得当,它们可以极大地加速开发。如果使用不当,它们会引入微妙的错误,削弱您对自己的 codebase 的理解,并造成生产力的错误感觉。
-
正确的心态 model:代理是速度快但缺乏经验的结对程序员。它可以快速编写代码,了解语法和标准模式,并且阅读的文档比您以往任何时候都多。但它不了解您的特定系统、业务限制、边缘情况或设计决策背后的“原因”。你是高级工程师;代理人是初级的。您指导、审查并承担责任。
当代理商表现出色时¶
-
样板和脚手架:生成 Dockerfile、CI 配置、测试 fixtures、数据类定义、argparse 设置。这些遵循众所周知的模式,并且手写起来很乏味。让代理生成它们,然后检查其正确性。
-
编写测试:描述函数的行为,代理生成测试用例。它经常捕获您可能会错过的边缘情况(空输入、负值、Unicode)。始终阅读生成的测试 - 它们验证您的假设,而不仅仅是您的代码。
-
重构:“将此 block 提取到函数中”,“将此类转换为使用数据类”,“向此模block添加类型提示”。意图明确且发生细微错误的风险较低的机械转换。
-
探索和原型设计:“编写一个快速的 script 来基准 inference 延迟”或“向我展示如何使用 HuggingFace 标记器 API”。该代理可以比阅读文档更快地为您提供工作起点。
-
文档和文档字符串:代理可以根据您的代码结构生成文档。检查准确性,但繁重的工作是自动化的。
-
调试帮助:粘贴错误回溯并请求诊断。代理通常可以识别根本原因并建议修复,特别是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。
何时不依赖代理商¶
-
新颖的架构决策:如果您正在设计新的 training 管道,代理将为您提供通用答案。它不知道您的数据限制、延迟要求或团队专业知识。使用代理来实现您已经想好的设计。
-
安全关键代码:身份验证、加密、输入清理。代理可能会生成看起来正确但存在细微漏洞(SQL 注入、不安全的默认值、定时攻击)的代码。安全代码应由了解威胁 model 的人员编写,并由其他人审核。
-
性能关键的内部循环:代理将编写正确但幼稚的代码。对于 GPU kernels、内存关键型数据结构或延迟敏感的服务路径,您需要了解硬件约束(第 13 章、第 16 章)并进行刻意优化。
-
你不理解的代码:如果代理生成 200 行,而你无法解释每行的作用,请不要 commit 它。您现在正在维护您不理解的代码,当它损坏时(它会损坏),您无法调试它。这是最常见、最危险的故障模式。
审查纪律¶
-
在commit之前,请务必阅读生成代码的每一行。这不是可选的。代理的代码是草稿,不是成品。就像对待同事的 pull request 一样:批判性地审查它。
-
要检查什么:
- 正确性:它确实按照您的要求进行吗?代理解决的问题常常与您想要的问题略有不同。
- 边缘情况:它是否处理空输入、无值、负数、非常大的输入?代理经常忽略边缘情况处理。
- 幻觉的 API:代理可能会调用不存在的函数或使用不存在的参数,特别是对于较新或不太常见的库。验证每个 API 调用都是真实的。
- 过度设计:代理往往会生成不必要的代码。用 50 行解决 10 行问题会增加复杂性,但没有任何好处。无情地简化。
- 安全性:硬编码的秘密、未经消毒的用户输入、不安全的默认设置。代理人并不以敌对的方式思考。
- 风格一致性:生成的代码是否符合您项目的约定(命名、模式、错误处理)?
如何写出好的提示¶
-
座席输出的质量与您的指导质量成正比。模糊的提示会得到模糊的代码。
-
不好:“编写数据加载器”
-
好:“为包含‘text’和‘label’列的 CSV 文件编写一个 PyTorch DataLoader。使用 HuggingFace tokeniser 'bert-base-uncased' 和 max_length=512 对文本进行标记。返回 input_ids、attention_mask 和标签作为张量。通过跳过这些行来处理 CSV 在标签列中缺少值的情况。”
-
提供上下文:告诉代理您的项目结构、现有代码、约束和约定。上下文越多,输出就越好。
-
指定约束:“仅使用标准 library”、“必须与 Python 3.10 配合使用”、“不使用全局变量”、“遵循
src/models/transformer.py中的现有模式”。 -
请求解释:“实施 X 并解释关键设计决策。”这迫使智能体阐明其inference,使您更容易发现有缺陷的假设。
使用质量门来发现代理错误¶
-
您现有的质量基础设施(文件 04)可以捕获代理错误以及人为错误:
- 类型检查(mypy):捕获幻觉的 API 签名和类型不匹配。
- Linting (ruff):捕获未使用的导入、未定义的变量和样式违规。
- 测试 (pytest):如果代理的代码通过了您的测试套件,则它更有可能是正确的。如果您没有测试,请在要求代理实现该功能之前编写测试(测试驱动开发对于代理尤其有效)。
- CI 管道:在每个 commit 上自动运行上述所有内容。
-
“代理编写代码”+“质量门对其进行验证”的组合比单独使用任何一个都更有效率。代理速度快但马虎;门很透彻,但不写代码。总之,您可以获得速度和正确性。
生产力陷阱¶
-
编码代理最大的风险是生产力的错觉。您可以在 10 分钟内生成 500 行代码。但如果你因为不理解而花 2 小时调试这 500 行,那么你比自己在 30 分钟内写 200 行还慢。
-
座席的真正生产力来自于:
- 保持控制:您决定架构,代理填写实现。
- 了解生成的内容:如果无法解释,请重写或要求代理简化。
- 投资质量门:测试、类型和 linting 在每次代理交互中摊销其成本。
- 使用代理来弥补你的弱点:如果你擅长算法但编写测试很慢,那么让代理编写测试。如果您擅长 UI 代码但不熟悉数据库查询,请让代理起草 SQL。发挥你的优势,弥补你的差距。
-
充分利用编码代理的工程师是那些已经知道如何良好编码的工程师。代理人增强您现有的技能;它不会取代它。了解数据结构、算法、系统设计和软件工程(整章)可以让您有效地指导代理并批判性地评估其输出。