Skip to content

Vision Transformers 与生成

视觉转换器将 self-attention 应用于图像块,通过数据驱动的空间学习挑战 CNN 的主导地位。该文件涵盖 ViT、DeiT、Swin Transformer、使用 GAN (StyleGAN)、VAE 和 diffusion 模型(DDPM,稳定扩散)生成图像,以及超分辨率和神经风格迁移。

  • CNN(文件 02)构建了强大的空间归纳偏差:局部连接、权重共享和平移等变性。 Vision Transformers (ViTs) 提出了一个挑衅性的问题:如果我们完全放弃这些偏差并让模型仅使用第 06 章中的 attention 机制从数据中学习空间结构会怎样?

  • Vision Transformer (ViT)(Dosovitskiy 等人,2021)将标准 Transformer encoder 直接应用于图像。关键思想是将图像视为补丁序列,就像 NLP 将文本视为 tokens 序列一样。

  • 该过程的工作原理如下:

    1. 将图像(高度 \(H\)、宽度 \(W\)、通道 \(C\))分割为大小为 \(P \times P\) 的非重叠补丁网格。这会产生 \(N = HW / P^2\) 补丁。
    2. 将每个补丁展平为长度为 \(P^2 \cdot C\) 的向量,并通过学习的线性 embedding (单个矩阵乘法,第 02 章)将其投影到模型维度 \(D\)
    3. 前置一个可学习的 [CLS] token embedding (类似于 BERT 的 [CLS],第 07 章)。该 token 关注所有补丁,其最终表示用于分类。
    4. 添加 位置 embeddings (每个位置一个可学习向量)以提供空间信息,因为 attention 是排列等变的。
    5. \((N + 1)\) token embeddings 的序列传递给标准 Transformer encoder (multi-head self-attention + FFN,第 06 章)。
    6. [CLS] token 的最终表示通过分类头(一个小的 MLP)。

ViT 管道:将图像分割成 16x16 的块,每个块均进行展平并线性投影,前置 [CLS] 标记,添加位置嵌入,然后由 Transformer 编码器块进行处理

  • 补丁 embedding 相当于内核大小 \(P\) 和步幅 \(P\) (非重叠)的卷积。 ViT 从字面上将 2D 图像转换为 1D 序列,然后使用与语言相同的架构对其进行处理。

  • ViT 比 CNN 具有更少的归纳偏差:它不强制执行局部连接或平移等变性。这意味着它需要更多的训练数据来从头开始学习空间结构。在小型数据集上,CNN 的性能优于 ViT。但是,当在非常大的数据集(JFT-300M,3 亿张图像)上进行训练时,ViT 匹配或超过了最好的 CNN,这表明 CNN 的归纳偏差有助于提高数据效率,但对于最终性能而言并不是必需的。

  • ViT self-attention 是补丁数 \(O(N^2)\)。对于具有 16x16 补丁的 224x224 图像,\(N = 196\),这是可以管理的。但对于更高分辨率的图像或更小的补丁,二次成本变得令人望而却步。

  • DeiT(数据高效图像 Transformer,Touvron 等人,2021)表明,使用强大的数据增强、正则化(随机深度、标签平滑、丢弃)和知识蒸馏,可以仅在 ImageNet 上有效地训练 ViT(无需大量 JFT 数据集):预先训练的 CNN 教师提供 ViT 学生学习匹配的软标签。 DeiT 在 [CLS] token 旁边添加了 蒸馏 token,经过训练以预测教师的输出。

  • Swin Transformer (Liu et al., 2021) 解决了 ViT 的两个主要局限性:其随图像大小二次增长的计算成本以及缺乏分层特征图(检测和分割需要)。

  • Swin 引入 移位窗口: attention 是在本地窗口(例如 7x7 补丁)内计算的,而不是所有补丁上的全局 self-attention 。这使得成本与图像大小呈线性关系:\(O(N)\) 而不是 \(O(N^2)\)。但仅本地窗口就会阻止区域之间的信息流动。

  • 窗口移动解决了这个问题:在交替层中,窗口分区移动窗口大小的一半。这会创建跨窗口连接,允许信息跨层在图像的所有部分之间流动,而无需全局 attention 的成本。

Swin Transformer:l 层计算常规窗口内的注意力,l+1 层将窗口分区移动一半,创建跨窗口连接

  • Swin 还通过跨阶段合并补丁来构建分层表示。在每个阶段之后,相邻的 2x2 块被连接并投影,以使通道尺寸加倍并将空间分辨率减半。这会产生类似于 CNN 和 FPN(文件 03)中的多尺度特征图,使 Swin 直接与 Faster R-CNN 等检测头和 U-Net 等分割头兼容。

  • PVT (Pyramid Vision Transformer) 采用类似的具有空间缩减 attention 的分层方法:在每个阶段,在计算 attention 之前对键和值进行空间下采样,减少二次成本,同时保持全局感受野。

  • 自监督视觉学习训练未标记图像的表示。收集标签的成本很高,但图像却很丰富。目标是学习无需任何人工注释即可很好地转移到下游任务的功能。

  • 对比学习训练模型认识到同一图像的两个增强视图(“正对”)应具有相似的表示,而不同图像的视图(“负对”)应具有不同的表示。

  • SimCLR(Chen 等人,2020)批量创建每个图像的两个增强视图,使用共享主干 + 投影头对两者进行编码,并应用 NT-Xent 损失(归一化温度缩放交叉熵):

\[\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}\]
  • 其中 \(\text{sim}\) 是余弦相似度(第 01 章),\(\tau\) 是温度参数。分子将正对推到一起;分母将负数对分开。 SimCLR 需要大批量(4,096+)才能提供足够的负样本。

  • MoCo(Momentum Contrast,He et al.,2020)通过维护负 embeddings 的动量更新队列来解决大批量要求。查询 encoder 通过梯度下降进行更新;键 encoder 更新为查询 encoder 的指数移动平均线(EMA,第 04 章):\(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\),带有 \(m = 0.999\)。队列存储最近的密钥 embeddings,提供大量且一致的负数集,而无需大量批量。

  • BYOL(Bootstrap Your Own Latent,Grill et al.,2020)完全消除负对。它使用两个网络:“在线”网络和“目标”网络(在线的 EMA)。在线网络预测目标网络对不同增强视图的表示。在没有负数的情况下,BYOL 通过预测器头部和 EMA 目标的不对称性避免了崩溃问题(模型为所有内容输出相同的向量)。

  • DINO(无标签自蒸馏,Caron 等人,2021)将自蒸馏应用于 ViT。学生网络预测教师网络在不同增强视图上的输出(学生的 EMA)。老师使用较大的作物;学生使用较小的作物。 DINO 生成包含有关场景布局的显式信息的特征:经过 DINO 训练的 ViT 的 self-attention 地图自然地分割对象,无需任何分割监督。

  • 屏蔽图像建模是 BERT 屏蔽语言建模(第 07 章)的视觉模拟。大部分输入补丁被屏蔽,模型学习重建它们。

  • MAE(Masked Autoencoders,He et al.,2022)屏蔽 75% 的补丁并训练 ViT encoder-decoder 来重建丢失的像素值。 encoder 仅处理未屏蔽的补丁(在预训练期间节省 4 倍计算),轻量级 decoder 从编码的可见补丁加上可学习的掩模 tokens 重建完整图像。

  • BEiT(BERT 图像 Transformers 的预训练,Bao 等人,2022)掩盖补丁并预测离散视觉 tokens(从预先训练的 dVAE 标记器获得)而不是原始像素。这与 BERT 对离散词 tokens 的预测并行,并避免了像素重建的低级细节。

  • 图像生成旨在生成训练集中不存在的新的、真实的图像。核心挑战是对自然图像的高维概率分布进行建模。

  • 生成对抗网络 (GAN)(Goodfellow 等人,2014)使用两个竞争网络:一个从随机噪声创建假图像的 生成器 \(G\),以及一个尝试区分真实图像和假图像的 鉴别器 \(D\)。他们接受对抗性训练:\(G\) 试图愚弄 \(D\)\(D\) 试图抓住 \(G\)

\[\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]\]
  • 生成器采用随机潜在向量 \(z\) (从高斯等简单分布中采样)并通过一系列转置卷积将其映射以生成图像。判别器是标准的 CNN 分类器。在平衡状态下,\(G\) 生成的图像与真实数据无法区分,而 \(D\) 对于所有输入输出 0.5。

  • 模式崩溃是 GAN 的主要失败模式:生成器只学习生成几种类型的图像来欺骗鉴别器,而忽略了训练数据的多样性。生成器找到一小组“安全”输出,而不是覆盖整个分布。

  • 稳定 GAN 的训练技巧包括:谱归一化(限制鉴别器的 Lipschitz 常数)、渐进增长(首先以低分辨率训练,然后逐渐增加)、特征匹配(匹配中间鉴别器特征的统计数据而不是最终输出)以及使用 Wasserstein 距离代替原始 JS 散度目标。

  • StyleGAN (Karras et al., 2019) 是最具影响力的 GAN 高质量图像合成架构。它的关键创新是基于样式的生成器:它不是将潜在向量 \(z\) 直接输入到生成器中,而是首先通过 映射网络(8 层 MLP)进行映射以生成样式向量 \(w\)。该样式向量通过自适应实例归一化(AdaIN)注入到生成器的每一层,它调节特征图统计数据:

\[\text{AdaIN}(x, y) = y_{s} \cdot \frac{x - \mu(x)}{\sigma(x)} + y_{b}\]
  • 其中 \(y_s\)\(y_b\) 是从 \(w\) 得出的比例和偏差。不同的图层控制不同的方面:早期图层控制粗略特征(姿势、脸部形状),中间图层控制中等特征(发型、眼睛),后期图层控制精细细节(雀斑、头发纹理)。 StyleGAN 可以生成 1024x1024 分辨率的逼真面孔。

  • 变分自动编码器(VAE)(第 06 章)提供了另一种生成方法。与 GAN 不同,VAE 具有原则性的概率框架和明确的训练目标 (ELBO)。它们往往会产生比 GAN 更模糊的图像,但提供更平滑、更结构化的潜在空间。 VAE 是潜在 diffusion 模型中使用的 encoder-decoder 对,用于将图像压缩到潜在空间或从潜在空间压缩图像。

  • 扩散模型已成为图像生成的主导范式,在质量和多样性方面都超越了 GAN。这个想法在概念上很简单:逐渐向数据添加噪声,直到它变成纯高斯噪声(前向过程),然后学习逐步反转这个过程(反向过程)。

  • 前向过程\(T\) 时间步长上添加高斯噪声:

\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I)\]
  • 其中 \(\beta_t\) 是随时间增加的噪声表。经过足够的步骤后,无论原始图像 \(x_0\) 为何,\(x_T\) 都近似为纯高斯噪声。使用重新参数化技巧(第 06 章)并设置 \(\alpha_t = 1 - \beta_t\)\(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\),我们可以直接从 \(x_0\) 采样 \(x_t\)
\[x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
  • 反向过程学习去噪:从纯噪声 \(x_T\) 开始,模型预测每一步添加的噪声 \(\epsilon\) 并将其减去以恢复 \(x_{t-1}\)。这是由神经网络 \(\epsilon_\theta\) (通常是 U-Net,来自文件 03)进行参数化,并使用简单的 MSE 损失进行训练:
\[\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]\]

扩散正向和反向过程:干净的图像在T步(正向)中逐渐被噪声破坏,神经网络学习反向每一步(反向),从纯噪声开始生成干净的图像

  • DDPM(去噪扩散概率模型,Ho 等人,2020)建立了这个框架。采样需要迭代所有 \(T\) 步骤(通常为 1,000),速度很慢。 DDIM(去噪扩散隐式模型,Song 等人,2021)将采样过程重新表述为确定性映射,允许大步跳过(例如,50 个步骤而不是 1,000 个),同时质量损失最小。

  • 基于分数的模型(Song 和 Ermon,2019)提供了另一种视角。该模型不是预测噪声 \(\epsilon\),而是估计 评分函数 \(\nabla_{x_t} \log p(x_t)\),即相对于噪声图像的对数概率的梯度。该梯度指向数据分布的更高概率(更干净)的区域。使用朗之万动力学采样遵循该梯度。基于分数的模型和DDPM统一在随机微分方程(SDE)的框架中:正向过程是添加噪声的SDE,逆向过程是时间反转SDE。

  • 无分类器指导(Ho 和 Salimans,2022)控制样本质量和多样性之间的权衡。该模型经过有条件(使用文本 prompt 或类标签)和无条件(随机删除条件)的训练。在采样时,预测是加权组合:

\[\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))\]
  • 其中 \(c\) 是条件,\(\varnothing\) 是空条件,\(s > 1\) 是指导尺度。 \(s\) 越高,生成的图像与条件的匹配越强,但多样性越低。 \(s = 1\) 给出无引导模型; \(s = 7.5\) 是常见的默认值。

  • 潜在 diffusion (Rombach 等人,2022;稳定扩散)将 diffusion 过程从像素空间移动到学习的潜在空间。预训练的 VAE encoder 将图像压缩为低维潜在表示(通常为 4 倍或 8 倍空间下采样),diffusion 在此压缩空间中运行,而 VAE decoder 从去噪后的潜在表示中重建像素。这显着提高了效率:在像素空间中扩散 512x512 图像意味着处理 \(512 \times 512 \times 3\) 张量,但在潜在空间中仅处理 \(64 \times 64 \times 4\) 张量。

  • 潜伏 diffusion 中的去噪 U-Net 接收噪声潜伏、时间步长(编码为正弦 embedding,类似于 Transformer 中的位置编码)和条件信号(来自冻结 CLIP 的文本 embedding 或 T5 文本 encoder)。文本条件通过 U-Net 内的跨 attention 层输入:文本 embeddings 用作键和值,图像特征用作查询。这让模型能够关注文本 prompt 在每个空间位置的相关部分。

  • 流匹配是 diffusion 的一种新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是 DDPM 的迭代去噪。

  • 连续归一化流 (CNF) 定义了一个与时间相关的速度场 \(v_\theta(x, t)\),它将样本沿着平滑轨迹从简单分布 \(p_0\)(噪声)推送到数据分布 \(p_1\)。该变换遵循常微分方程 (ODE):

\[\frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1]\]
  • \(x_0 \sim \mathcal{N}(0, I)\) 开始,将 ODE 向前积分到 \(t = 1\) 会生成数据分布的样本。速度场由神经网络参数化并进行训练以匹配目标条件流。

  • 最优传输 (OT) flow matching (Lipman et al., 2023) 使用噪声和数据之间的直线路径作为目标流:从噪声样本 \(x_0\) 到数据样本 \(x_1\) 的条件路径就是 \(x_t = (1 - t) x_0 + t x_1\),目标速度是 \(v = x_1 - x_0\)。训练损失变为:

\[\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\]
  • 修正流(Liu et al., 2022)迭代地拉直学习到的流路径。经过初始训练后,模型用于通过模拟 ODE 来生成(噪声、数据)对。这些配对比随机配对更紧密地对齐,用于重新训练模型。重复此过程会产生越来越直的路径,可以用更少的 ODE 步骤(甚至单个步骤)遍历这些路径,从而实现极快的生成。

  • 与 diffusion 相比,流匹配有几个优点:训练目标更简单(直接速度回归,无噪声调度),采样 ODE 更平滑(需要更少的积分步骤),并且与最佳传输的连接提供了理论基础。 Stable Diffusion 3 和 Flux 使用 flow matching 而不是传统的 DDPM。

编码任务(使用 CoLab 或笔记本)

  1. 从头开始实施 ViT 补丁 embedding。将图像分割成补丁,压平它们,投影到模型尺寸,添加位置 embeddings,并在前面添加 [CLS] token。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def create_patch_embedding(image, patch_size, d_model, params):
        """Convert an image into a sequence of patch embeddings."""
        H, W, C = image.shape
        n_patches_h = H // patch_size
        n_patches_w = W // patch_size
        n_patches = n_patches_h * n_patches_w
    
        # Extract patches
        patches = []
        for i in range(n_patches_h):
            for j in range(n_patches_w):
                patch = image[i*patch_size:(i+1)*patch_size,
                              j*patch_size:(j+1)*patch_size, :]
                patches.append(patch.ravel())
        patches = jnp.stack(patches)  # (N, P*P*C)
    
        # Linear projection to d_model
        embeddings = patches @ params['proj_w'] + params['proj_b']  # (N, d_model)
    
        # Prepend CLS token
        cls_token = params['cls_token']  # (1, d_model)
        embeddings = jnp.concatenate([cls_token, embeddings], axis=0)  # (N+1, d_model)
    
        # Add position embeddings
        embeddings = embeddings + params['pos_embed']  # (N+1, d_model)
    
        return embeddings, patches
    
    # Setup
    H, W, C = 32, 32, 3
    patch_size = 8
    d_model = 64
    n_patches = (H // patch_size) * (W // patch_size)  # 16
    
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, 5)
    
    # Create a synthetic image with distinct quadrants
    image = jnp.zeros((H, W, C))
    image = image.at[:16, :16, 0].set(1.0)   # red top-left
    image = image.at[:16, 16:, 1].set(1.0)   # green top-right
    image = image.at[16:, :16, 2].set(1.0)   # blue bottom-left
    image = image.at[16:, 16:, :2].set(1.0)  # yellow bottom-right
    
    params = {
        'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02,
        'proj_b': jnp.zeros(d_model),
        'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02,
        'pos_embed': jax.random.normal(keys[2], (n_patches + 1, d_model)) * 0.02,
    }
    
    embeddings, patches = create_patch_embedding(image, patch_size, d_model, params)
    
    print(f"Image shape: {image.shape}")
    print(f"Patch size: {patch_size}x{patch_size}")
    print(f"Number of patches: {n_patches}")
    print(f"Patch vector length: {patch_size**2 * C}")
    print(f"Embedding shape: {embeddings.shape}  (CLS + {n_patches} patches)")
    
    # Visualise patches
    fig, axes = plt.subplots(2, 5, figsize=(14, 6))
    axes[0, 0].imshow(image); axes[0, 0].set_title('Full Image'); axes[0, 0].axis('off')
    for idx in range(min(9, n_patches)):
        ax = axes[(idx+1) // 5, (idx+1) % 5]
        patch_img = patches[idx].reshape(patch_size, patch_size, C)
        ax.imshow(patch_img); ax.set_title(f'Patch {idx}'); ax.axis('off')
    plt.suptitle('ViT Patch Decomposition')
    plt.tight_layout(); plt.show()
    

  2. 实现一个简单的 GAN 训练循环。在 2D 数据上训练生成器和判别器,并可视化生成的分布收敛到真实分布。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def generator(z, params):
        h = jnp.tanh(z @ params['g_w1'] + params['g_b1'])
        h = jnp.tanh(h @ params['g_w2'] + params['g_b2'])
        return h @ params['g_w3'] + params['g_b3']
    
    def discriminator(x, params):
        h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2)
        h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2)
        return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3'])
    
    def init_params(key):
        keys = jax.random.split(key, 6)
        z_dim, h_dim, data_dim = 2, 32, 2
        scale = 0.1
        return {
            'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale,
            'g_b1': jnp.zeros(h_dim),
            'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale,
            'g_b2': jnp.zeros(h_dim),
            'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale,
            'g_b3': jnp.zeros(data_dim),
            'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale,
            'd_b1': jnp.zeros(h_dim),
            'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale,
            'd_b2': jnp.zeros(h_dim),
            'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale,
            'd_b3': jnp.zeros(1),
        }
    
    def d_loss(params, real_data, fake_data):
        real_score = discriminator(real_data, params)
        fake_score = discriminator(fake_data, params)
        return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7))
    
    def g_loss(params, fake_data):
        fake_score = discriminator(fake_data, params)
        return -jnp.mean(jnp.log(fake_score + 1e-7))
    
    # Real data: ring distribution
    key = jax.random.PRNGKey(42)
    theta = jax.random.uniform(key, (512,)) * 2 * jnp.pi
    real_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)
    real_data = real_data + jax.random.normal(key, real_data.shape) * 0.05
    
    params = init_params(jax.random.PRNGKey(0))
    d_grad = jax.grad(d_loss)
    g_grad = jax.grad(g_loss)
    lr = 0.001
    
    snapshots = []
    for step in range(3000):
        key, k1 = jax.random.split(key)
        z = jax.random.normal(k1, (512, 2))
        fake_data = generator(z, params)
    
        # Update discriminator
        grads = d_grad(params, real_data, fake_data)
        for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']:
            params[k] = params[k] - lr * grads[k]
    
        # Update generator
        fake_data = generator(z, params)
        grads = g_grad(params, fake_data)
        for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']:
            params[k] = params[k] - lr * grads[k]
    
        if step in [0, 500, 1500, 2999]:
            snapshots.append((step, fake_data.copy()))
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (step, fake) in zip(axes, snapshots):
        ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='Real')
        ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='Generated')
        ax.set_title(f'Step {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2)
        ax.set_aspect('equal'); ax.legend(markerscale=3)
    plt.suptitle('GAN Training: Generator Learns the Ring Distribution')
    plt.tight_layout(); plt.show()
    

  3. 实现 diffusion 前向过程:以增加的时间步长向图像添加噪声并可视化渐进的损坏。然后实施单个去噪步骤。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def noise_schedule(T, beta_start=0.0001, beta_end=0.02):
        """Linear noise schedule."""
        betas = jnp.linspace(beta_start, beta_end, T)
        alphas = 1.0 - betas
        alpha_bars = jnp.cumprod(alphas)
        return betas, alphas, alpha_bars
    
    def forward_diffusion(x0, t, alpha_bars, key):
        """Add noise to x0 at timestep t."""
        alpha_bar_t = alpha_bars[t]
        noise = jax.random.normal(key, x0.shape)
        xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise
        return xt, noise
    
    # Create a simple 2D "image" (checkerboard)
    img = jnp.zeros((32, 32))
    for i in range(4):
        for j in range(4):
            if (i + j) % 2 == 0:
                img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0)
    
    T = 1000
    betas, alphas, alpha_bars = noise_schedule(T)
    
    # Visualise forward process
    timesteps = [0, 50, 200, 500, 999]
    key = jax.random.PRNGKey(42)
    
    fig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5))
    for ax, t in zip(axes, timesteps):
        key, subkey = jax.random.split(key)
        xt, noise = forward_diffusion(img, t, alpha_bars, subkey)
        ax.imshow(xt, cmap='gray', vmin=-2, vmax=2)
        ax.set_title(f't={t}\n$\\bar{{\\alpha}}$={alpha_bars[t]:.3f}')
        ax.axis('off')
    plt.suptitle('Diffusion Forward Process: Progressive Noise Addition')
    plt.tight_layout(); plt.show()
    
    # Simple denoising: train a tiny network to predict noise at t=200
    t_denoise = 200
    key, k1 = jax.random.split(key)
    xt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1)
    
    # Tiny "denoiser": just learn a constant noise estimate (for illustration)
    noise_estimate = jnp.zeros_like(img)
    lr = 0.01
    for step in range(100):
        residual = noise_estimate - true_noise
        noise_estimate = noise_estimate - lr * residual
    
    # Reverse one step
    alpha_bar_t = alpha_bars[t_denoise]
    x_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t)
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img, cmap='gray'); axes[0].set_title('Original $x_0$'); axes[0].axis('off')
    axes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2)
    axes[1].set_title(f'Noisy $x_{{200}}$'); axes[1].axis('off')
    axes[2].imshow(x_denoised, cmap='gray')
    axes[2].set_title('Denoised (one step)'); axes[2].axis('off')
    plt.tight_layout(); plt.show()
    
    mse = jnp.mean((x_denoised - img)**2)
    print(f"Denoising MSE: {mse:.4f}")