测试和质量保证¶
测试是您了解代码是否有效的方法,不仅是现在,而且是在每次更改之后。该文件涵盖了 test pyramid、unit tests 和 pytest、mocking、测试 ML 特定代码、CI/CD 管道、linting、formatting 和 code review,以及在 bug 进入生产之前捕获 bug 的实践。
-
众所周知,机器学习代码的测试不足。 “它能training,所以它有效”是普遍的态度。这会导致无声错误:数据加载器错误地洗牌、带有符号错误的损失函数、丢弃 5% 数据的预处理步骤。这些错误不会使您的程序崩溃。它们只会让您的 model 悄然变得更糟,并且您会浪费数周时间来调试“应该更高”的指标。
-
测试不是开销。这是在不破坏事物的情况下快速移动的最快方法。
测试金字塔¶
-
测试是分层组织的,从快速和狭窄到缓慢和广泛:
-
单元测试(基础):单独测试各个函数和类。快(毫秒)、多(数百到数千)。 “
normalise_image会产生 [0, 1] 中的值吗?” -
集成测试(中):测试组件是否协同工作。较慢(秒)。 “数据加载器是否以 model 期望的格式生成批次?”
-
端到端测试(上):测试从输入到输出的完整管道。慢(分钟)。 “
python train.py --config test.yaml是否完成且没有错误并生成有效的 checkpoint?”
-
-
金字塔形状的意思是:写很多unit tests,写很少的integration tests,写少量的end-to-end tests。单元测试可以捕获大多数错误并在几秒钟内运行。端到端测试可以发现集成问题,但速度缓慢且脆弱。
使用 pytest 进行单元测试¶
- pytest是标准的Python测试框架。测试是在以
test_开头的文件中以test_开头的函数:
# tests/test_utils.py
def test_normalise_image():
import numpy as np
image = np.array([0, 128, 255], dtype=np.uint8)
result = normalise_image(image, mean=128, std=128)
assert result.min() >= -1.0
assert result.max() <= 1.0
assert abs(result[1]) < 1e-6 # 128 normalised by mean=128 should be ~0
def test_normalise_empty():
import numpy as np
image = np.array([], dtype=np.uint8)
result = normalise_image(image, mean=128, std=128)
assert len(result) == 0
pytest tests/ # run all tests
pytest tests/test_utils.py # run one file
pytest -v # verbose output
pytest -x # stop on first failure
pytest -k "normalise" # run tests matching name pattern
pytest --tb=short # shorter tracebacks
固定装置¶
- 夹具为测试提供可重复使用的设置。不要在每个测试中重复设置代码,而是定义一次:
import pytest
@pytest.fixture
def sample_dataset():
"""Create a small dataset for testing."""
return {
"inputs": torch.randn(10, 3, 32, 32),
"labels": torch.randint(0, 10, (10,))
}
@pytest.fixture
def trained_model():
"""Load a small pretrained model."""
model = SmallModel()
model.load_state_dict(torch.load("tests/fixtures/small_model.pt"))
return model
def test_model_output_shape(trained_model, sample_dataset):
output = trained_model(sample_dataset["inputs"])
assert output.shape == (10, 10) # batch_size x num_classes
- 夹具可以有范围:
scope="function"(默认,每次测试新鲜),scope="module"(每个文件一次),scope="session"(每次测试运行一次)。使用scope="session"进行昂贵的设置,例如加载 model。
参数化测试¶
- 使用多个输入测试同一函数,无需重复代码:
@pytest.mark.parametrize("input,expected", [
([1, 2, 3], 6),
([], 0),
([-1, 1], 0),
([1000000, 1000000], 2000000),
])
def test_sum(input, expected):
assert sum(input) == expected
模拟和修补¶
- 模拟在测试过程中用假依赖替换真正的依赖。这使您可以单独测试函数,而无需数据库、API 或 GPU。
from unittest.mock import patch, MagicMock
def test_training_logs_metrics():
mock_logger = MagicMock()
with patch("my_project.training.trainer.wandb") as mock_wandb:
trainer = Trainer(logger=mock_logger)
trainer.train_one_epoch()
# verify that the trainer logged metrics
mock_logger.log.assert_called()
# verify it logged a loss value
call_args = mock_logger.log.call_args
assert "loss" in call_args[1]
-
何时模拟:外部服务(API、数据库、云存储)、昂贵的操作(GPU 计算、大文件 I/O)和非确定性行为(随机数生成器、时间戳)。
-
何时不进行模拟:您自己的代码。如果您模拟所有内容,您的测试将验证模拟的行为是否符合预期,而不是您的代码是否有效。模拟边界,直接测试你的逻辑。
测试机器学习代码¶
- ML 代码具有独特的测试挑战:输出是概率性的,training 很慢,并且“正确”是模糊的。
确定性种子¶
- 到处设置随机种子以使测试可重现:
import random
import numpy as np
import torch
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
数值公差¶
- 浮点比较需要容差(IEEE 754 第 13 章):
# BAD: exact comparison fails due to floating point
assert model_output == 0.5
# GOOD: approximate comparison
import numpy as np
assert np.isclose(model_output, 0.5, atol=1e-5)
# For tensors
assert torch.allclose(output, expected, atol=1e-4)
在 ML 中测试什么¶
- 形状测试:验证输出是否具有预期尺寸。
def test_model_output_shape():
model = MyModel(d_model=256, n_classes=10)
x = torch.randn(8, 32, 256) # batch=8, seq=32, dim=256
output = model(x)
assert output.shape == (8, 10)
- 梯度流:验证可training参数的梯度是否非零。
def test_gradients_flow():
model = MyModel()
x = torch.randn(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
for name, param in model.named_parameters():
assert param.grad is not None, f"No gradient for {name}"
assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
- 一批上的过度拟合:model 应该能够记住单个批次。如果不能,那就有根本性的错误。
def test_overfit_one_batch():
model = MyModel()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = get_single_batch()
for _ in range(100):
loss = F.cross_entropy(model(x), y)
loss.backward()
optimiser.step()
optimiser.zero_grad()
assert loss.item() < 0.01, f"Cannot overfit one batch: loss={loss.item()}"
- 数据验证:验证数据加载产生有效的输出。
def test_dataset_basics():
dataset = MyDataset("tests/fixtures/small_data.csv")
assert len(dataset) > 0
x, y = dataset[0]
assert x.shape == (3, 224, 224)
assert 0 <= y < 10
assert not torch.isnan(x).any()
assert not torch.isinf(x).any()
- 决定论:相同的输入+相同的种子→相同的输出。
def test_determinism():
set_seed(42)
output1 = model(input_data)
set_seed(42)
output2 = model(input_data)
assert torch.allclose(output1, output2)
CI/CD 管道¶
-
持续集成 (CI):在每个 commit 或 PR 上自动运行测试。如果测试失败,则 PR 无法merge。这可以防止损坏的代码到达
main。 -
GitHub 操作示例 (
.github/workflows/ci.yml):
name: CI
on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: pip install -e ".[dev]"
- run: ruff check src/
- run: mypy src/
- run: pytest tests/ -v --tb=short
- commit 之前的钩子:在每个 commit(本地)之前运行检查,在问题到达 CI 之前捕获问题:
# .pre-commit-config.yaml
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
检测和格式化¶
-
Linting 无需运行代码即可捕获错误和样式问题。 格式化自动强制执行一致的样式。
-
Ruff:一种快速的 Python linter 和格式化程序(在一个工具中取代 flake8、isort 和 black):
- mypy:Python 的静态类型检查器。在运行前捕获类型错误:
mypy src/
# src/model.py:42: error: Argument 1 to "forward" has incompatible type "int"; expected "Tensor"
- 类型提示使代码自我记录并捕获错误:
def train(
model: nn.Module,
dataloader: DataLoader,
optimiser: torch.optim.Optimizer,
num_epochs: int = 10,
) -> float:
"""Train model and return final loss."""
...
代码审查最佳实践¶
-
致作者:
- 在请求审核之前自我审核您的差异。你会发现明显的问题。
- 保持 PR 小而集中。 PR 有一个问题。
- 写出清晰的描述:测试什么、为什么、如何测试。
- 回复每条评论(即使只是“完成”)。
-
对于审稿人:
- 友善一点。批评代码,而不是人。 “这可能会更清楚”而不是“这令人困惑”。
- 区分阻塞问题(错误、安全)和建议(样式、命名)。使用标签:“nit:”,“建议:”,“阻止:”。
- 提出问题而不是提出要求。 “如果这个列表是空的,会发生什么?”比“处理空箱子”更有帮助。
- 及时批准。 PR 等待审核的天数 blocks 作者并鼓励大型、批量的 PR(这更难审核)。