Skip to content

测试和质量保证

测试是您了解代码是否有效的方法,不仅是现在,而且是在每次更改之后。该文件涵盖了 test pyramid、unit tests 和 pytest、mocking、测试 ML 特定代码、CI/CD 管道、linting、formatting 和 code review,以及在 bug 进入生产之前捕获 bug 的实践。

  • 众所周知,机器学习代码的测试不足。 “它能training,所以它有效”是普遍的态度。这会导致无声错误:数据加载器错误地洗牌、带有符号错误的损失函数、丢弃 5% 数据的预处理步骤。这些错误不会使您的程序崩溃。它们只会让您的 model 悄然变得更糟,并且您会浪费数周时间来调试“应该更高”的指标。

  • 测试不是开销。这是在不破坏事物的情况下快速移动的最快方法。

测试金字塔

  • 测试是分层组织的,从快速和狭窄到缓慢和广泛:

    • 单元测试(基础):单独测试各个函数和类。快(毫秒)、多(数百到数千)。 “normalise_image 会产生 [0, 1] 中的值吗?”

    • 集成测试(中):测试组件是否协同工作。较慢(秒)。 “数据加载器是否以 model 期望的格式生成批次?”

    • 端到端测试(上):测试从输入到输出的完整管道。慢(分钟)。 “python train.py --config test.yaml 是否完成且没有错误并生成有效的 checkpoint?”

  • 金字塔形状的意思是:写很多unit tests,写很少的integration tests,写少量的end-to-end tests。单元测试可以捕获大多数错误并在几秒钟内运行。端到端测试可以发现集成问题,但速度缓慢且脆弱。

使用 pytest 进行单元测试

  • pytest是标准的Python测试框架。测试是在以 test_ 开头的文件中以 test_ 开头的函数:
# tests/test_utils.py

def test_normalise_image():
    import numpy as np
    image = np.array([0, 128, 255], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert result.min() >= -1.0
    assert result.max() <= 1.0
    assert abs(result[1]) < 1e-6  # 128 normalised by mean=128 should be ~0

def test_normalise_empty():
    import numpy as np
    image = np.array([], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert len(result) == 0
pytest tests/                     # run all tests
pytest tests/test_utils.py        # run one file
pytest -v                         # verbose output
pytest -x                         # stop on first failure
pytest -k "normalise"             # run tests matching name pattern
pytest --tb=short                 # shorter tracebacks

固定装置

  • 夹具为测试提供可重复使用的设置。不要在每个测试中重复设置代码,而是定义一次:
import pytest

@pytest.fixture
def sample_dataset():
    """Create a small dataset for testing."""
    return {
        "inputs": torch.randn(10, 3, 32, 32),
        "labels": torch.randint(0, 10, (10,))
    }

@pytest.fixture
def trained_model():
    """Load a small pretrained model."""
    model = SmallModel()
    model.load_state_dict(torch.load("tests/fixtures/small_model.pt"))
    return model

def test_model_output_shape(trained_model, sample_dataset):
    output = trained_model(sample_dataset["inputs"])
    assert output.shape == (10, 10)  # batch_size x num_classes
  • 夹具可以有范围scope="function"(默认,每次测试新鲜),scope="module"(每个文件一次),scope="session"(每次测试运行一次)。使用 scope="session" 进行昂贵的设置,例如加载 model。

参数化测试

  • 使用多个输入测试同一函数,无需重复代码:
@pytest.mark.parametrize("input,expected", [
    ([1, 2, 3], 6),
    ([], 0),
    ([-1, 1], 0),
    ([1000000, 1000000], 2000000),
])
def test_sum(input, expected):
    assert sum(input) == expected

模拟和修补

  • 模拟在测试过程中用假依赖替换真正的依赖。这使您可以单独测试函数,而无需数据库、API 或 GPU。
from unittest.mock import patch, MagicMock

def test_training_logs_metrics():
    mock_logger = MagicMock()

    with patch("my_project.training.trainer.wandb") as mock_wandb:
        trainer = Trainer(logger=mock_logger)
        trainer.train_one_epoch()

        # verify that the trainer logged metrics
        mock_logger.log.assert_called()
        # verify it logged a loss value
        call_args = mock_logger.log.call_args
        assert "loss" in call_args[1]
  • 何时模拟:外部服务(API、数据库、云存储)、昂贵的操作(GPU 计算、大文件 I/O)和非确定性行为(随机数生成器、时间戳)。

  • 何时不进行模拟:您自己的代码。如果您模拟所有内容,您的测试将验证模拟的行为是否符合预期,而不是您的代码是否有效。模拟边界,直接测试你的逻辑。

测试机器学习代码

  • ML 代码具有独特的测试挑战:输出是概率性的,training 很慢,并且“正确”是模糊的。

确定性种子

  • 到处设置随机种子以使测试可重现:
import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

数值公差

  • 浮点比较需要容差(IEEE 754 第 13 章):
# BAD: exact comparison fails due to floating point
assert model_output == 0.5

# GOOD: approximate comparison
import numpy as np
assert np.isclose(model_output, 0.5, atol=1e-5)

# For tensors
assert torch.allclose(output, expected, atol=1e-4)

在 ML 中测试什么

  • 形状测试:验证输出是否具有预期尺寸。
def test_model_output_shape():
    model = MyModel(d_model=256, n_classes=10)
    x = torch.randn(8, 32, 256)  # batch=8, seq=32, dim=256
    output = model(x)
    assert output.shape == (8, 10)
  • 梯度流:验证可training参数的梯度是否非零。
def test_gradients_flow():
    model = MyModel()
    x = torch.randn(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))

    output = model(x)
    loss = F.cross_entropy(output, y)
    loss.backward()

    for name, param in model.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
  • 一批上的过度拟合:model 应该能够记住单个批次。如果不能,那就有根本性的错误。
def test_overfit_one_batch():
    model = MyModel()
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = get_single_batch()

    for _ in range(100):
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    assert loss.item() < 0.01, f"Cannot overfit one batch: loss={loss.item()}"
  • 数据验证:验证数据加载产生有效的输出。
def test_dataset_basics():
    dataset = MyDataset("tests/fixtures/small_data.csv")
    assert len(dataset) > 0
    x, y = dataset[0]
    assert x.shape == (3, 224, 224)
    assert 0 <= y < 10
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()
  • 决定论:相同的输入+相同的种子→相同的输出。
def test_determinism():
    set_seed(42)
    output1 = model(input_data)
    set_seed(42)
    output2 = model(input_data)
    assert torch.allclose(output1, output2)

CI/CD 管道

  • 持续集成 (CI):在每个 commit 或 PR 上自动运行测试。如果测试失败,则 PR 无法merge。这可以防止损坏的代码到达 main

  • GitHub 操作示例 (.github/workflows/ci.yml):

name: CI
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - run: pip install -e ".[dev]"
      - run: ruff check src/
      - run: mypy src/
      - run: pytest tests/ -v --tb=short
  • commit 之前的钩子:在每个 commit(本地)之前运行检查,在问题到达 CI 之前捕获问题:
# .pre-commit-config.yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.0
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
pip install pre-commit
pre-commit install    # now hooks run on every git commit

检测和格式化

  • Linting 无需运行代码即可捕获错误和样式问题。 格式化自动强制执行一致的样式。

  • Ruff:一种快速的 Python linter 和格式化程序(在一个工具中取代 flake8、isort 和 black):

ruff check src/          # lint
ruff check --fix src/    # lint and auto-fix
ruff format src/         # format
  • mypy:Python 的静态类型检查器。在运行前捕获类型错误:
mypy src/
# src/model.py:42: error: Argument 1 to "forward" has incompatible type "int"; expected "Tensor"
  • 类型提示使代码自我记录并捕获错误:
def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimiser: torch.optim.Optimizer,
    num_epochs: int = 10,
) -> float:
    """Train model and return final loss."""
    ...

代码审查最佳实践

  • 致作者

    • 在请求审核之前自我审核您的差异。你会发现明显的问题。
    • 保持 PR 小而集中。 PR 有一个问题。
    • 写出清晰的描述:测试什么、为什么、如何测试。
    • 回复每条评论(即使只是“完成”)。
  • 对于审稿人

    • 友善一点。批评代码,而不是人。 “这可能会更清楚”而不是“这令人困惑”。
    • 区分阻塞问题(错误、安全)和建议(样式、命名)。使用标签:“nit:”,“建议:”,“阻止:”。
    • 提出问题而不是提出要求。 “如果这个列表是空的,会发生什么?”比“处理空箱子”更有帮助。
    • 及时批准。 PR 等待审核的天数 blocks 作者并鼓励大型、批量的 PR(这更难审核)。