Skip to content

Object Detection and Segmentation

目标检测对图像中的每个物体进行定位和分类;segmentation 为每个 pixel 分配标签。本文涵盖 IoU、mAP、anchor box、R-CNN 系列、YOLO、SSD、Feature Pyramid Network、语义/实例/全景 segmentation(U-Net、Mask R-CNN、SAM)以及用于基准测试的指标。

  • 图像分类(文件 02)回答"图像中有什么?"目标检测提出了一个更难的问题:"图像中有哪些物体,它们在哪里?"

  • Segmentation 更进一步:"哪些 pixel 属于哪个物体或类别?"这些任务形成了空间理解精度不断提高的层次结构。

  • 目标检测模型输出一组 bounding box,每个 bounding box 由四个坐标(左上角 \(x, y\),宽度,高度)以及带置信度分数的类别标签定义。一张图像可能包含零个、一个或数百个来自多个类别的物体。

包含多个物体的输入图像,每个物体用带类别标签和置信度分数的彩色 bounding box 框出

  • Intersection over Union(IoU) 衡量预测的 bounding box 与 ground truth 的匹配程度。它是重叠面积除以联合面积:
\[\text{IoU} = \frac{\text{Area of Intersection}}{\text{Area of Union}}\]
  • IoU 为 1 表示完美重叠;IoU 为 0 表示完全不重叠。"正确"检测的标准阈值为 IoU \(\geq 0.5\),但也使用更严格的阈值(0.75、0.9)。

  • 若检测到的 bounding box 与 ground truth box 的 IoU 超过阈值且类别正确,则为 true positive(TP)

  • False positive(FP) 是与任何 ground truth 不匹配的预测 box。

  • False negative(FN) 是没有预测与之匹配的 ground truth 物体。这些是第 06 章中相同的精确率/召回率概念。

  • Average Precision(AP) 汇总一个类别的检测质量。对于每个类别,按置信度分数对所有检测排序,计算每个排名处的精确率和召回率,并计算精确率-召回率曲线下面积:

\[\text{AP} = \int_0^1 p(r) \, dr\]
  • 实际上,曲线是插值的:在每个召回率级别,精确率设为在任何召回率 \(\geq r\) 时的最大精确率。这平滑了曲线并使其单调递减。

  • Mean Average Precision(mAP) 对所有类别取 AP 的平均。"mAP@0.5" 使用 IoU 阈值 0.5。"mAP@[.5:.95]"(COCO 标准)在 0.5 到 0.95 之间以 0.05 为步长对十个 IoU 阈值取 mAP 的平均,同时奖励检测和精确定位。

  • Non-Maximum Suppression(NMS) 去除重复检测。当模型对同一物体预测多个重叠 box 时,NMS 保留置信度最高的 box,并移除所有与其 IoU 超过阈值的其他 box。这在模型产生原始预测后按类别应用。

  • 两阶段检测器首先提出候选区域,然后对每个候选区域进行分类和精化。

  • R-CNN(Girshick 等,2014)是第一个成功的深度学习检测器。它使用选择性搜索(一种经典算法)提出约 2,000 个候选区域,将每个区域裁剪到固定大小,独立地将每个区域通过 CNN 处理,然后用 SVM(第 06 章)进行分类。R-CNN 准确但速度极慢:每张图像需要运行 CNN 2,000 次。

  • Fast R-CNN(Girshick,2015)通过在整个图像上只运行一次 CNN 以产生共享 feature map 来解决冗余问题,然后使用 RoI pooling(感兴趣区域 pooling)从该共享 map 中为每个候选区域提取 feature。

  • RoI pooling 将 feature map 中大小可变的区域划分为网格,并在每个单元内进行 max pooling,从而产生固定大小的输出。这快得多,因为昂贵的 CNN 计算只发生一次。

  • Faster R-CNN(Ren 等,2015)通过引入 Region Proposal Network(RPN) 消除了外部区域提案算法,RPN 是一个在共享 feature map 之上运行的小型 CNN,直接预测候选区域。RPN 在 feature map 上滑动小窗口,在每个位置预测 \(k\) 个候选区域(每个 anchor box 对应一个)。

Faster R-CNN 流程:输入图像 → backbone CNN → 共享 feature map → RPN 生成候选区域 → RoI pooling → 分类和 box 回归头

  • Anchor box 是 feature map 每个空间位置上预定义的 bounding box,覆盖不同的尺度和宽高比(例如,三种尺度 × 三种比例 = 每个位置 9 个 anchor)。RPN 对每个 anchor 预测两件事:目标性分数(目标 vs 背景)以及将 anchor 精化为更紧密候选区域的坐标偏移。这种参数化使回归问题更容易:网络不是预测绝对坐标,而是预测对合理起始 box 的小调整。

  • Anchor 偏移的参数化如下:

\[t_x = \frac{x - x_a}{w_a}, \quad t_y = \frac{y - y_a}{h_a}, \quad t_w = \log\frac{w}{w_a}, \quad t_h = \log\frac{h}{h_a}\]
  • 其中 \((x, y, w, h)\) 是预测 box 的中心和尺寸,\((x_a, y_a, w_a, h_a)\) 是 anchor。对宽度和高度的对数变换确保预测 box 始终为正,并使回归具有尺度不变性。

  • Faster R-CNN 使用多任务损失进行训练:类别标签的分类损失(来自第 05 章的交叉熵),加上 box 回归的 smooth L1 损失。Smooth L1 对异常值的敏感度低于 L2:

\[ \text{smooth}_{L1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases} \]
  • Feature Pyramid Network(FPN)(Lin 等,2017)通过构建带横向连接的自顶向下通路来解决多尺度问题,将高级语义与低级空间细节融合。backbone 在多个尺度上产生 feature map(每个 pooling layer 将分辨率减半)。FPN 增加了一条自顶向下的通路,每个层次接收从上方上采样的 feature,并通过横向 1x1 convolution 与相应的自底向上层次合并。结果是一个 feature map 金字塔,每个层次同时具有强语义和良好的空间分辨率。

  • 从金字塔的高分辨率层次检测小物体;从低分辨率层次检测大物体。FPN 现在是大多数现代检测架构的标准组件。

  • 单阶段检测器完全跳过候选区域步骤,在单次前向传播中预测类别标签和 bounding box。这速度更快,但历史上比两阶段检测器精度更低,直到 focal loss 弥合了这一差距。

  • YOLO(You Only Look Once,Redmon 等,2016)将图像划分为 \(S \times S\) 网格。每个网格单元预测 \(B\) 个 bounding box 和 \(C\) 个类别概率。如果物体的中心落在某个网格单元内,该单元负责检测该物体。YOLO 极快,因为整个检测是无候选区域阶段的单次前向传播。

  • YOLOv2 添加了 anchor box、batch normalisation 和多尺度训练。YOLOv3 使用 Feature Pyramid Network 并在三个尺度上预测。YOLOv4-v8 持续改进,包括更好的 backbone、路径聚合网络以及马赛克 data augmentation(在训练期间将四张图像拼接在一起以增加上下文多样性)。

  • SSD(Single Shot MultiBox Detector,Liu 等,2016)在 backbone 内的多个 feature map 尺度上进行预测,在每个尺度使用 anchor box。早期(高分辨率)feature map 检测小物体;后期(低分辨率)map 检测大物体。SSD 比 Faster R-CNN 更快,准确率也具有竞争力。

  • RetinaNet(Lin 等,2017)识别了单阶段检测器的核心问题:类别不平衡。绝大多数 anchor box 对应背景,产生大量容易的负样本,这些样本主导了损失,并淹没了来自少数正样本的 gradient。

  • Focal loss 通过降低简单样本的权重来解决这个问题:

\[\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)\]
  • 其中 \(p_t\) 是正确类别的预测概率。当模型置信且正确时(\(p_t\) 高),\((1 - p_t)^\gamma\) 很小,减少了简单负样本的损失贡献。超参数 \(\gamma\)(通常为 2)控制降权的强度。当 \(\gamma = 0\) 时,focal loss 退化为标准交叉熵。使用 focal loss,RetinaNet 以单阶段的速度实现了与两阶段检测器相当的准确率。

  • 无 anchor 检测完全消除 anchor box,减少超参数调整并简化流程。

  • FCOS(Fully Convolutional One-Stage,Tian 等,2019)在 feature map 的每个空间位置预测该位置到最近 bounding box 四条边(左、上、右、下)的距离加上类别标签。中心度分数降低距离物体中心较远的预测的权重,提高质量。FCOS 使用 FPN 处理多个尺度。

  • CenterNet(Zhou 等,2019)将物体检测为点:预测一个热力图,其中峰值对应物体中心,然后在每个峰值处回归宽度和高度。检测变为关键点估计。这种方式简洁且无 anchor,但需要仔细的热力图后处理。

  • CornerNet 将物体检测为角点对(左上角和右下角)。它预测两个热力图(每种角点类型一个),并使用关联嵌入将对应角点匹配成 bounding box。这避免了对 anchor 的需求,并处理任意形状的物体。

  • 语义 segmentation 为图像中的每个 pixel 分配一个类别标签。与检测(输出 box)不同,segmentation 生成密集的逐 pixel 标签图。街道场景可能将每个 pixel 标记为道路、人行道、汽车、行人、建筑物、天空等。

语义 segmentation:输入街道场景及其逐 pixel 标签图,每种颜色代表一个类别

  • Fully Convolutional Network(FCN)(Long 等,2015)通过将全连接 layer 替换为卷积 layer,使分类 CNN 适用于 segmentation,允许网络输出空间图而非单一类别。通过上采样(转置 convolution 或双线性插值)将输出恢复到输入分辨率。来自早期 layer 的 skip connection 补充回了下采样过程中丢失的空间细节。

  • 转置 convolution(有时称为"反卷积")是 convolution 的上采样对应物。带 stride 的 convolution 减小空间维度,而转置 convolution 增大空间维度。它在输入元素之间插入零,然后应用标准 convolution,有效地学习如何上采样。

  • U-Net(Ronneberger 等,2015)引入了在每个层次都有 skip connection 的对称 encoder-decoder 架构。encoder(收缩路径)在增加 channel 的同时降低空间分辨率,与分类 CNN 完全相同。decoder(扩张路径)上采样回全分辨率。Skip connection 在每个层次将 encoder feature map 与 decoder feature map 拼接,为 decoder 提供精细的空间细节。这种高级语义和低级细节的结合产生了清晰、准确的 segmentation 边界。

U-Net 架构:左侧是带下采样的 encoder 路径,右侧是带上采样的 decoder 路径,skip connection 连接对应层次

  • U-Net 最初是为生物医学图像 segmentation 设计的(训练数据稀缺),其架构已成为许多后续模型的基础,包括 latent diffusion model 中的 U-Net(文件 04)。

  • DeepLab(Chen 等,2014-2018)为 segmentation 引入了两个关键创新:

    • Atrous(dilated)convolution:在 filter 元素之间插入间隔的标准 convolution,由 dilation rate \(r\) 控制。dilation 为 \(r\) 的 3x3 filter 具有 \((2r + 1) \times (2r + 1)\) 的感受野,但只使用 9 个参数。这在不下采样的情况下捕获多尺度上下文,保留空间分辨率。

    • Atrous Spatial Pyramid Pooling(ASPP):并行应用多个具有不同 dilation rate 的 atrous convolution(例如,rate 为 1、6、12、18),拼接结果,并用 1x1 convolution 融合。ASPP 同时捕获多个尺度的上下文,精神上类似于 Inception 模块(文件 02),但使用 dilation 而非不同 kernel 尺寸。

  • DeepLab 还使用条件随机场(CRF)(第 05 章)作为后处理步骤,通过鼓励空间上相邻且颜色相似的 pixel 共享相同标签来精化 segmentation 边界。

  • 实例 segmentation 结合了检测和 segmentation:识别每个独立物体实例,并为每个实例产生逐 pixel 掩码。场景中的两辆汽车得到两个独立的掩码,而不仅仅是两者都标为"汽车"。

  • Mask R-CNN(He 等,2017)通过添加一个为每个检测到的物体预测二进制掩码的小型 segmentation 头来扩展 Faster R-CNN。架构为 Faster R-CNN + 掩码分支:掩码分支接收 RoI pooling 后的 feature,并输出每个类别的 \(m \times m\) 二进制掩码。它使用 RoIAlign 替代 RoI pooling:在精确采样点处进行双线性插值,而非量化网格单元,避免了量化引起的空间错位。这一小改动显著提高了掩码质量。

  • Mask R-CNN 以多任务损失训练:分类损失 + box 回归损失 + 掩码损失(逐 pixel 二进制交叉熵)。掩码分支独立预测每个类别的掩码;只使用与预测类别对应的掩码,将掩码预测与分类解耦并同时改进两者。

  • 全景 segmentation 将语义和实例 segmentation 统一为单一任务。每个 pixel 同时获得类别标签(语义)和实例 ID(实例,针对汽车、人等"thing"类别)。"Stuff"类别(天空、道路、草地)只获得语义标签,因为它们是没有可计数实例的无定形区域。

  • 全景质量(PQ)指标通过将其分解为分割质量(匹配段的平均 IoU)和识别质量(匹配段的 F1 分数)来评估:

\[\text{PQ} = \underbrace{\frac{\sum_{(p,g) \in \text{TP}} \text{IoU}(p,g)}{|\text{TP}|}}_{\text{SQ}} \times \underbrace{\frac{|\text{TP}|}{|\text{TP}| + \frac{1}{2}|\text{FP}| + \frac{1}{2}|\text{FN}|}}_{\text{RQ}}\]
  • 实时 segmentation 对于自动驾驶和增强现实等应用至关重要,这些场景的延迟预算很紧(通常每帧低于 30 毫秒)。

  • BiSeNet(Bilateral Segmentation Network,Yu 等,2018)使用两条并行路径:空间路径(宽而浅的 layer 保留空间细节)和上下文路径(深而窄的 layer 捕获语义)。输出融合,同时兼顾速度和准确率。

  • DDRNet(Deep Dual-Resolution Network,Hong 等,2021)在整个网络中维护两个不同分辨率的分支,并在它们之间反复进行信息交换。高分辨率分支保留空间细节,低分辨率分支捕获全局上下文。多个双向融合模块在两个方向上合并信息。

  • 实时 segmentation 的总体趋势是避免沉重的 encoder-decoder 模式,转而在整个网络中保持足够的空间分辨率,以换取大幅更低的延迟为代价牺牲一些准确率。

编程任务(使用 CoLab 或 notebook)

  1. 从零实现 IoU 计算和 Non-Maximum Suppression。将 NMS 应用于一组重叠 bounding box,并将结果可视化。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    
    def compute_iou(box1, box2):
        """计算两个 box 的 IoU [x1, y1, x2, y2]。"""
        x1 = jnp.maximum(box1[0], box2[0])
        y1 = jnp.maximum(box1[1], box2[1])
        x2 = jnp.minimum(box1[2], box2[2])
        y2 = jnp.minimum(box1[3], box2[3])
    
        intersection = jnp.maximum(0, x2 - x1) * jnp.maximum(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection
    
        return intersection / (union + 1e-6)
    
    def nms(boxes, scores, iou_threshold=0.5):
        """Non-Maximum Suppression。"""
        order = jnp.argsort(-scores)  # 按置信度降序排序
        keep = []
    
        remaining = list(range(len(scores)))
        order_list = order.tolist()
    
        while order_list:
            idx = order_list[0]
            keep.append(idx)
            order_list = order_list[1:]
    
            new_order = []
            for j in order_list:
                iou = compute_iou(boxes[idx], boxes[j])
                if iou < iou_threshold:
                    new_order.append(j)
            order_list = new_order
    
        return keep
    
    # 示例:同一物体的多个重叠检测
    boxes = jnp.array([
        [50, 60, 150, 160],   # 高置信度
        [55, 65, 155, 165],   # 重叠副本
        [52, 58, 148, 158],   # 重叠副本
        [200, 100, 300, 200], # 不同物体
        [205, 105, 305, 205], # 重叠副本
    ])
    scores = jnp.array([0.95, 0.80, 0.70, 0.90, 0.60])
    
    keep = nms(boxes, scores, iou_threshold=0.5)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12']
    
    for ax, title, indices in zip(axes, ['Before NMS', 'After NMS'],
                                   [range(len(boxes)), keep]):
        ax.set_xlim(0, 400); ax.set_ylim(0, 300)
        ax.set_aspect('equal'); ax.invert_yaxis()
        ax.set_title(title)
        for i in indices:
            b = boxes[i]
            rect = patches.Rectangle((b[0], b[1]), b[2]-b[0], b[3]-b[1],
                                      linewidth=2, edgecolor=colors[i],
                                      facecolor='none')
            ax.add_patch(rect)
            ax.text(b[0], b[1]-5, f'{scores[i]:.2f}', color=colors[i], fontsize=10)
    
    plt.tight_layout(); plt.show()
    print(f"Kept {len(keep)} of {len(boxes)} boxes after NMS")
    

  2. 实现简化的 Region Proposal Network(RPN)。给定 feature map,生成多种尺度和宽高比的 anchor box,并预测目标性分数和 box 偏移。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    
    def generate_anchors(feature_h, feature_w, stride, scales, ratios):
        """为 feature map 上的每个位置生成 anchor box。"""
        anchors = []
        for y in range(feature_h):
            for x in range(feature_w):
                cx = (x + 0.5) * stride
                cy = (y + 0.5) * stride
                for s in scales:
                    for r in ratios:
                        w = s * jnp.sqrt(r)
                        h = s / jnp.sqrt(r)
                        anchors.append([cx - w/2, cy - h/2, cx + w/2, cy + h/2])
        return jnp.array(anchors)
    
    def rpn_forward(feature_map, params):
        """简化的 RPN:预测每个 anchor 的目标性分数和 box 偏移。"""
        H, W, C = feature_map.shape
        n_anchors = params['cls_w'].shape[1]
    
        # 在 feature map 上滑动 1x1 conv(简化)
        cls_scores = feature_map.reshape(-1, C) @ params['cls_w']  # (H*W, n_anchors)
        box_offsets = feature_map.reshape(-1, C) @ params['reg_w']  # (H*W, n_anchors*4)
    
        cls_scores = jax.nn.sigmoid(cls_scores)
        return cls_scores.ravel(), box_offsets.reshape(-1, 4)
    
    # 参数设置
    feature_h, feature_w, channels = 4, 4, 16
    stride = 16  # 每个 feature map 单元覆盖 16x16 个 pixel
    scales = [32, 64, 128]
    ratios = [0.5, 1.0, 2.0]
    n_anchors_per_pos = len(scales) * len(ratios)
    
    key = jax.random.PRNGKey(42)
    k1, k2, k3 = jax.random.split(key, 3)
    
    feature_map = jax.random.normal(k1, (feature_h, feature_w, channels))
    params = {
        'cls_w': jax.random.normal(k2, (channels, n_anchors_per_pos)) * 0.01,
        'reg_w': jax.random.normal(k3, (channels, n_anchors_per_pos * 4)) * 0.01,
    }
    
    anchors = generate_anchors(feature_h, feature_w, stride, scales, ratios)
    scores, offsets = rpn_forward(feature_map, params)
    
    print(f"Feature map: {feature_h}x{feature_w}, stride={stride}")
    print(f"Anchors per position: {n_anchors_per_pos}")
    print(f"Total anchors: {len(anchors)}")
    print(f"Objectness scores shape: {scores.shape}")
    print(f"Box offsets shape: {offsets.shape}")
    
    # 可视化一个位置的 anchor
    fig, ax = plt.subplots(figsize=(6, 6))
    img_size = feature_h * stride
    ax.set_xlim(0, img_size); ax.set_ylim(0, img_size)
    ax.invert_yaxis(); ax.set_aspect('equal')
    
    pos_idx = feature_h // 2 * feature_w + feature_w // 2  # 中心位置
    colors = ['#3498db', '#e74c3c', '#27ae60']
    for i, s in enumerate(scales):
        for j, r in enumerate(ratios):
            idx = pos_idx * n_anchors_per_pos + i * len(ratios) + j
            a = anchors[idx]
            rect = patches.Rectangle((a[0], a[1]), a[2]-a[0], a[3]-a[1],
                                      linewidth=1.5, edgecolor=colors[i],
                                      facecolor='none', linestyle=['--', '-', ':'][j])
            ax.add_patch(rect)
    
    ax.scatter([img_size/2], [img_size/2], c='red', s=50, zorder=5)
    ax.set_title(f'Anchors at centre position\n3 scales × 3 ratios = {n_anchors_per_pos}')
    ax.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()
    

  3. 实现用于一维 segmentation 的简化 U-Net encoder-decoder(对一维信号进行二值标注)。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def conv1d_same(x, kernel):
        """带 same padding 的 1D convolution。"""
        k = len(kernel)
        pad = k // 2
        x_pad = jnp.pad(x, pad, mode='edge')
        n = len(x)
        out = jnp.zeros(n)
        for i in range(n):
            out = out.at[i].set(jnp.sum(x_pad[i:i+k] * kernel))
        return out
    
    def downsample(x):
        return x[::2]
    
    def upsample(x, target_len):
        return jnp.interp(jnp.linspace(0, 1, target_len), jnp.linspace(0, 1, len(x)), x)
    
    def unet_1d(x, params):
        """带 2 个 encoder/decoder 层的简化 1D U-Net。"""
        # Encoder
        e1 = jnp.maximum(0, conv1d_same(x, params['enc1']))
        e1_down = downsample(e1)
    
        e2 = jnp.maximum(0, conv1d_same(e1_down, params['enc2']))
        e2_down = downsample(e2)
    
        # Bottleneck
        bottleneck = jnp.maximum(0, conv1d_same(e2_down, params['bottleneck']))
    
        # Decoder 带 skip connection
        d2_up = upsample(bottleneck, len(e2))
        d2 = jnp.maximum(0, conv1d_same(d2_up + e2, params['dec2']))  # skip connection
    
        d1_up = upsample(d2, len(e1))
        d1 = conv1d_same(d1_up + e1, params['dec1'])  # skip connection
    
        return jax.nn.sigmoid(d1)
    
    # 创建带标注区域的信号
    n = 128
    t = jnp.linspace(0, 4 * jnp.pi, n)
    signal = jnp.sin(t) + 0.5 * jnp.sin(3 * t)
    labels = (signal > 0.5).astype(jnp.float32)  # 二值 segmentation 目标
    
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, 5)
    params = {
        'enc1': jax.random.normal(keys[0], (5,)) * 0.3,
        'enc2': jax.random.normal(keys[1], (5,)) * 0.3,
        'bottleneck': jax.random.normal(keys[2], (3,)) * 0.3,
        'dec2': jax.random.normal(keys[3], (5,)) * 0.3,
        'dec1': jax.random.normal(keys[4], (5,)) * 0.3,
    }
    
    def loss_fn(params, signal, labels):
        pred = unet_1d(signal, params)
        return -jnp.mean(labels * jnp.log(pred + 1e-7) + (1 - labels) * jnp.log(1 - pred + 1e-7))
    
    grad_fn = jax.jit(jax.grad(loss_fn))
    lr = 0.05
    
    for step in range(500):
        grads = grad_fn(params, signal, labels)
        params = {k: params[k] - lr * grads[k] for k in params}
    
    pred = unet_1d(signal, params)
    
    fig, axes = plt.subplots(3, 1, figsize=(12, 7), sharex=True)
    axes[0].plot(t, signal, color='#3498db', linewidth=1.5)
    axes[0].set_title('Input Signal'); axes[0].set_ylabel('Value')
    
    axes[1].fill_between(t, 0, labels, alpha=0.3, color='#27ae60')
    axes[1].set_title('Ground Truth Labels'); axes[1].set_ylabel('Label')
    
    axes[2].plot(t, pred, color='#e74c3c', linewidth=1.5)
    axes[2].fill_between(t, 0, (pred > 0.5).astype(float), alpha=0.2, color='#e74c3c')
    axes[2].set_title('U-Net Prediction'); axes[2].set_ylabel('Probability')
    axes[2].set_xlabel('t')
    
    plt.tight_layout(); plt.show()
    print(f"Final loss: {loss_fn(params, signal, labels):.4f}")
    print(f"Pixel accuracy: {jnp.mean((pred > 0.5) == labels):.2%}")