Vision Transformers 与生成¶
视觉转换器将 self-attention 应用于图像块,通过数据驱动的空间学习挑战 CNN 的主导地位。该文件涵盖 ViT、DeiT、Swin Transformer、使用 GAN (StyleGAN)、VAE 和 diffusion 模型(DDPM,稳定扩散)生成图像,以及超分辨率和神经风格迁移。
-
CNN(文件 02)构建了强大的空间归纳偏差:局部连接、权重共享和平移等变性。 Vision Transformers (ViTs) 提出了一个挑衅性的问题:如果我们完全放弃这些偏差并让模型仅使用第 06 章中的 attention 机制从数据中学习空间结构会怎样?
-
Vision Transformer (ViT)(Dosovitskiy 等人,2021)将标准 Transformer encoder 直接应用于图像。关键思想是将图像视为补丁序列,就像 NLP 将文本视为 tokens 序列一样。
-
该过程的工作原理如下:
- 将图像(高度 \(H\)、宽度 \(W\)、通道 \(C\))分割为大小为 \(P \times P\) 的非重叠补丁网格。这会产生 \(N = HW / P^2\) 补丁。
- 将每个补丁展平为长度为 \(P^2 \cdot C\) 的向量,并通过学习的线性 embedding (单个矩阵乘法,第 02 章)将其投影到模型维度 \(D\) 。
- 前置一个可学习的 [CLS] token embedding (类似于 BERT 的 [CLS],第 07 章)。该 token 关注所有补丁,其最终表示用于分类。
- 添加 位置 embeddings (每个位置一个可学习向量)以提供空间信息,因为 attention 是排列等变的。
- 将 \((N + 1)\) token embeddings 的序列传递给标准 Transformer encoder (multi-head self-attention + FFN,第 06 章)。
- [CLS] token 的最终表示通过分类头(一个小的 MLP)。
-
补丁 embedding 相当于内核大小 \(P\) 和步幅 \(P\) (非重叠)的卷积。 ViT 从字面上将 2D 图像转换为 1D 序列,然后使用与语言相同的架构对其进行处理。
-
ViT 比 CNN 具有更少的归纳偏差:它不强制执行局部连接或平移等变性。这意味着它需要更多的训练数据来从头开始学习空间结构。在小型数据集上,CNN 的性能优于 ViT。但是,当在非常大的数据集(JFT-300M,3 亿张图像)上进行训练时,ViT 匹配或超过了最好的 CNN,这表明 CNN 的归纳偏差有助于提高数据效率,但对于最终性能而言并不是必需的。
-
ViT self-attention 是补丁数 \(O(N^2)\)。对于具有 16x16 补丁的 224x224 图像,\(N = 196\),这是可以管理的。但对于更高分辨率的图像或更小的补丁,二次成本变得令人望而却步。
-
DeiT(数据高效图像 Transformer,Touvron 等人,2021)表明,使用强大的数据增强、正则化(随机深度、标签平滑、丢弃)和知识蒸馏,可以仅在 ImageNet 上有效地训练 ViT(无需大量 JFT 数据集):预先训练的 CNN 教师提供 ViT 学生学习匹配的软标签。 DeiT 在 [CLS] token 旁边添加了 蒸馏 token,经过训练以预测教师的输出。
-
Swin Transformer (Liu et al., 2021) 解决了 ViT 的两个主要局限性:其随图像大小二次增长的计算成本以及缺乏分层特征图(检测和分割需要)。
-
Swin 引入 移位窗口: attention 是在本地窗口(例如 7x7 补丁)内计算的,而不是所有补丁上的全局 self-attention 。这使得成本与图像大小呈线性关系:\(O(N)\) 而不是 \(O(N^2)\)。但仅本地窗口就会阻止区域之间的信息流动。
-
窗口移动解决了这个问题:在交替层中,窗口分区移动窗口大小的一半。这会创建跨窗口连接,允许信息跨层在图像的所有部分之间流动,而无需全局 attention 的成本。
-
Swin 还通过跨阶段合并补丁来构建分层表示。在每个阶段之后,相邻的 2x2 块被连接并投影,以使通道尺寸加倍并将空间分辨率减半。这会产生类似于 CNN 和 FPN(文件 03)中的多尺度特征图,使 Swin 直接与 Faster R-CNN 等检测头和 U-Net 等分割头兼容。
-
PVT (Pyramid Vision Transformer) 采用类似的具有空间缩减 attention 的分层方法:在每个阶段,在计算 attention 之前对键和值进行空间下采样,减少二次成本,同时保持全局感受野。
-
自监督视觉学习训练未标记图像的表示。收集标签的成本很高,但图像却很丰富。目标是学习无需任何人工注释即可很好地转移到下游任务的功能。
-
对比学习训练模型认识到同一图像的两个增强视图(“正对”)应具有相似的表示,而不同图像的视图(“负对”)应具有不同的表示。
-
SimCLR(Chen 等人,2020)批量创建每个图像的两个增强视图,使用共享主干 + 投影头对两者进行编码,并应用 NT-Xent 损失(归一化温度缩放交叉熵):
-
其中 \(\text{sim}\) 是余弦相似度(第 01 章),\(\tau\) 是温度参数。分子将正对推到一起;分母将负数对分开。 SimCLR 需要大批量(4,096+)才能提供足够的负样本。
-
MoCo(Momentum Contrast,He et al.,2020)通过维护负 embeddings 的动量更新队列来解决大批量要求。查询 encoder 通过梯度下降进行更新;键 encoder 更新为查询 encoder 的指数移动平均线(EMA,第 04 章):\(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\),带有 \(m = 0.999\)。队列存储最近的密钥 embeddings,提供大量且一致的负数集,而无需大量批量。
-
BYOL(Bootstrap Your Own Latent,Grill et al.,2020)完全消除负对。它使用两个网络:“在线”网络和“目标”网络(在线的 EMA)。在线网络预测目标网络对不同增强视图的表示。在没有负数的情况下,BYOL 通过预测器头部和 EMA 目标的不对称性避免了崩溃问题(模型为所有内容输出相同的向量)。
-
DINO(无标签自蒸馏,Caron 等人,2021)将自蒸馏应用于 ViT。学生网络预测教师网络在不同增强视图上的输出(学生的 EMA)。老师使用较大的作物;学生使用较小的作物。 DINO 生成包含有关场景布局的显式信息的特征:经过 DINO 训练的 ViT 的 self-attention 地图自然地分割对象,无需任何分割监督。
-
屏蔽图像建模是 BERT 屏蔽语言建模(第 07 章)的视觉模拟。大部分输入补丁被屏蔽,模型学习重建它们。
-
MAE(Masked Autoencoders,He et al.,2022)屏蔽 75% 的补丁并训练 ViT encoder-decoder 来重建丢失的像素值。 encoder 仅处理未屏蔽的补丁(在预训练期间节省 4 倍计算),轻量级 decoder 从编码的可见补丁加上可学习的掩模 tokens 重建完整图像。
-
BEiT(BERT 图像 Transformers 的预训练,Bao 等人,2022)掩盖补丁并预测离散视觉 tokens(从预先训练的 dVAE 标记器获得)而不是原始像素。这与 BERT 对离散词 tokens 的预测并行,并避免了像素重建的低级细节。
-
图像生成旨在生成训练集中不存在的新的、真实的图像。核心挑战是对自然图像的高维概率分布进行建模。
-
生成对抗网络 (GAN)(Goodfellow 等人,2014)使用两个竞争网络:一个从随机噪声创建假图像的 生成器 \(G\),以及一个尝试区分真实图像和假图像的 鉴别器 \(D\)。他们接受对抗性训练:\(G\) 试图愚弄 \(D\),\(D\) 试图抓住 \(G\)。
-
生成器采用随机潜在向量 \(z\) (从高斯等简单分布中采样)并通过一系列转置卷积将其映射以生成图像。判别器是标准的 CNN 分类器。在平衡状态下,\(G\) 生成的图像与真实数据无法区分,而 \(D\) 对于所有输入输出 0.5。
-
模式崩溃是 GAN 的主要失败模式:生成器只学习生成几种类型的图像来欺骗鉴别器,而忽略了训练数据的多样性。生成器找到一小组“安全”输出,而不是覆盖整个分布。
-
稳定 GAN 的训练技巧包括:谱归一化(限制鉴别器的 Lipschitz 常数)、渐进增长(首先以低分辨率训练,然后逐渐增加)、特征匹配(匹配中间鉴别器特征的统计数据而不是最终输出)以及使用 Wasserstein 距离代替原始 JS 散度目标。
-
StyleGAN (Karras et al., 2019) 是最具影响力的 GAN 高质量图像合成架构。它的关键创新是基于样式的生成器:它不是将潜在向量 \(z\) 直接输入到生成器中,而是首先通过 映射网络(8 层 MLP)进行映射以生成样式向量 \(w\)。该样式向量通过自适应实例归一化(AdaIN)注入到生成器的每一层,它调节特征图统计数据:
-
其中 \(y_s\) 和 \(y_b\) 是从 \(w\) 得出的比例和偏差。不同的图层控制不同的方面:早期图层控制粗略特征(姿势、脸部形状),中间图层控制中等特征(发型、眼睛),后期图层控制精细细节(雀斑、头发纹理)。 StyleGAN 可以生成 1024x1024 分辨率的逼真面孔。
-
变分自动编码器(VAE)(第 06 章)提供了另一种生成方法。与 GAN 不同,VAE 具有原则性的概率框架和明确的训练目标 (ELBO)。它们往往会产生比 GAN 更模糊的图像,但提供更平滑、更结构化的潜在空间。 VAE 是潜在 diffusion 模型中使用的 encoder-decoder 对,用于将图像压缩到潜在空间或从潜在空间压缩图像。
-
扩散模型已成为图像生成的主导范式,在质量和多样性方面都超越了 GAN。这个想法在概念上很简单:逐渐向数据添加噪声,直到它变成纯高斯噪声(前向过程),然后学习逐步反转这个过程(反向过程)。
-
前向过程在 \(T\) 时间步长上添加高斯噪声:
- 其中 \(\beta_t\) 是随时间增加的噪声表。经过足够的步骤后,无论原始图像 \(x_0\) 为何,\(x_T\) 都近似为纯高斯噪声。使用重新参数化技巧(第 06 章)并设置 \(\alpha_t = 1 - \beta_t\)、\(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\),我们可以直接从 \(x_0\) 采样 \(x_t\):
- 反向过程学习去噪:从纯噪声 \(x_T\) 开始,模型预测每一步添加的噪声 \(\epsilon\) 并将其减去以恢复 \(x_{t-1}\)。这是由神经网络 \(\epsilon_\theta\) (通常是 U-Net,来自文件 03)进行参数化,并使用简单的 MSE 损失进行训练:
-
DDPM(去噪扩散概率模型,Ho 等人,2020)建立了这个框架。采样需要迭代所有 \(T\) 步骤(通常为 1,000),速度很慢。 DDIM(去噪扩散隐式模型,Song 等人,2021)将采样过程重新表述为确定性映射,允许大步跳过(例如,50 个步骤而不是 1,000 个),同时质量损失最小。
-
基于分数的模型(Song 和 Ermon,2019)提供了另一种视角。该模型不是预测噪声 \(\epsilon\),而是估计 评分函数 \(\nabla_{x_t} \log p(x_t)\),即相对于噪声图像的对数概率的梯度。该梯度指向数据分布的更高概率(更干净)的区域。使用朗之万动力学采样遵循该梯度。基于分数的模型和DDPM统一在随机微分方程(SDE)的框架中:正向过程是添加噪声的SDE,逆向过程是时间反转SDE。
-
无分类器指导(Ho 和 Salimans,2022)控制样本质量和多样性之间的权衡。该模型经过有条件(使用文本 prompt 或类标签)和无条件(随机删除条件)的训练。在采样时,预测是加权组合:
-
其中 \(c\) 是条件,\(\varnothing\) 是空条件,\(s > 1\) 是指导尺度。 \(s\) 越高,生成的图像与条件的匹配越强,但多样性越低。 \(s = 1\) 给出无引导模型; \(s = 7.5\) 是常见的默认值。
-
潜在 diffusion (Rombach 等人,2022;稳定扩散)将 diffusion 过程从像素空间移动到学习的潜在空间。预训练的 VAE encoder 将图像压缩为低维潜在表示(通常为 4 倍或 8 倍空间下采样),diffusion 在此压缩空间中运行,而 VAE decoder 从去噪后的潜在表示中重建像素。这显着提高了效率:在像素空间中扩散 512x512 图像意味着处理 \(512 \times 512 \times 3\) 张量,但在潜在空间中仅处理 \(64 \times 64 \times 4\) 张量。
-
潜伏 diffusion 中的去噪 U-Net 接收噪声潜伏、时间步长(编码为正弦 embedding,类似于 Transformer 中的位置编码)和条件信号(来自冻结 CLIP 的文本 embedding 或 T5 文本 encoder)。文本条件通过 U-Net 内的跨 attention 层输入:文本 embeddings 用作键和值,图像特征用作查询。这让模型能够关注文本 prompt 在每个空间位置的相关部分。
-
流匹配是 diffusion 的一种新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是 DDPM 的迭代去噪。
-
连续归一化流 (CNF) 定义了一个与时间相关的速度场 \(v_\theta(x, t)\),它将样本沿着平滑轨迹从简单分布 \(p_0\)(噪声)推送到数据分布 \(p_1\)。该变换遵循常微分方程 (ODE):
-
从 \(x_0 \sim \mathcal{N}(0, I)\) 开始,将 ODE 向前积分到 \(t = 1\) 会生成数据分布的样本。速度场由神经网络参数化并进行训练以匹配目标条件流。
-
最优传输 (OT) flow matching (Lipman et al., 2023) 使用噪声和数据之间的直线路径作为目标流:从噪声样本 \(x_0\) 到数据样本 \(x_1\) 的条件路径就是 \(x_t = (1 - t) x_0 + t x_1\),目标速度是 \(v = x_1 - x_0\)。训练损失变为:
-
修正流(Liu et al., 2022)迭代地拉直学习到的流路径。经过初始训练后,模型用于通过模拟 ODE 来生成(噪声、数据)对。这些配对比随机配对更紧密地对齐,用于重新训练模型。重复此过程会产生越来越直的路径,可以用更少的 ODE 步骤(甚至单个步骤)遍历这些路径,从而实现极快的生成。
-
与 diffusion 相比,流匹配有几个优点:训练目标更简单(直接速度回归,无噪声调度),采样 ODE 更平滑(需要更少的积分步骤),并且与最佳传输的连接提供了理论基础。 Stable Diffusion 3 和 Flux 使用 flow matching 而不是传统的 DDPM。
编码任务(使用 CoLab 或笔记本)¶
-
从头开始实施 ViT 补丁 embedding。将图像分割成补丁,压平它们,投影到模型尺寸,添加位置 embeddings,并在前面添加 [CLS] token。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt def create_patch_embedding(image, patch_size, d_model, params): """Convert an image into a sequence of patch embeddings.""" H, W, C = image.shape n_patches_h = H // patch_size n_patches_w = W // patch_size n_patches = n_patches_h * n_patches_w # Extract patches patches = [] for i in range(n_patches_h): for j in range(n_patches_w): patch = image[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size, :] patches.append(patch.ravel()) patches = jnp.stack(patches) # (N, P*P*C) # Linear projection to d_model embeddings = patches @ params['proj_w'] + params['proj_b'] # (N, d_model) # Prepend CLS token cls_token = params['cls_token'] # (1, d_model) embeddings = jnp.concatenate([cls_token, embeddings], axis=0) # (N+1, d_model) # Add position embeddings embeddings = embeddings + params['pos_embed'] # (N+1, d_model) return embeddings, patches # Setup H, W, C = 32, 32, 3 patch_size = 8 d_model = 64 n_patches = (H // patch_size) * (W // patch_size) # 16 key = jax.random.PRNGKey(42) keys = jax.random.split(key, 5) # Create a synthetic image with distinct quadrants image = jnp.zeros((H, W, C)) image = image.at[:16, :16, 0].set(1.0) # red top-left image = image.at[:16, 16:, 1].set(1.0) # green top-right image = image.at[16:, :16, 2].set(1.0) # blue bottom-left image = image.at[16:, 16:, :2].set(1.0) # yellow bottom-right params = { 'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02, 'proj_b': jnp.zeros(d_model), 'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02, 'pos_embed': jax.random.normal(keys[2], (n_patches + 1, d_model)) * 0.02, } embeddings, patches = create_patch_embedding(image, patch_size, d_model, params) print(f"Image shape: {image.shape}") print(f"Patch size: {patch_size}x{patch_size}") print(f"Number of patches: {n_patches}") print(f"Patch vector length: {patch_size**2 * C}") print(f"Embedding shape: {embeddings.shape} (CLS + {n_patches} patches)") # Visualise patches fig, axes = plt.subplots(2, 5, figsize=(14, 6)) axes[0, 0].imshow(image); axes[0, 0].set_title('Full Image'); axes[0, 0].axis('off') for idx in range(min(9, n_patches)): ax = axes[(idx+1) // 5, (idx+1) % 5] patch_img = patches[idx].reshape(patch_size, patch_size, C) ax.imshow(patch_img); ax.set_title(f'Patch {idx}'); ax.axis('off') plt.suptitle('ViT Patch Decomposition') plt.tight_layout(); plt.show() -
实现一个简单的 GAN 训练循环。在 2D 数据上训练生成器和判别器,并可视化生成的分布收敛到真实分布。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt def generator(z, params): h = jnp.tanh(z @ params['g_w1'] + params['g_b1']) h = jnp.tanh(h @ params['g_w2'] + params['g_b2']) return h @ params['g_w3'] + params['g_b3'] def discriminator(x, params): h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2) h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2) return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3']) def init_params(key): keys = jax.random.split(key, 6) z_dim, h_dim, data_dim = 2, 32, 2 scale = 0.1 return { 'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale, 'g_b1': jnp.zeros(h_dim), 'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale, 'g_b2': jnp.zeros(h_dim), 'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale, 'g_b3': jnp.zeros(data_dim), 'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale, 'd_b1': jnp.zeros(h_dim), 'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale, 'd_b2': jnp.zeros(h_dim), 'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale, 'd_b3': jnp.zeros(1), } def d_loss(params, real_data, fake_data): real_score = discriminator(real_data, params) fake_score = discriminator(fake_data, params) return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7)) def g_loss(params, fake_data): fake_score = discriminator(fake_data, params) return -jnp.mean(jnp.log(fake_score + 1e-7)) # Real data: ring distribution key = jax.random.PRNGKey(42) theta = jax.random.uniform(key, (512,)) * 2 * jnp.pi real_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1) real_data = real_data + jax.random.normal(key, real_data.shape) * 0.05 params = init_params(jax.random.PRNGKey(0)) d_grad = jax.grad(d_loss) g_grad = jax.grad(g_loss) lr = 0.001 snapshots = [] for step in range(3000): key, k1 = jax.random.split(key) z = jax.random.normal(k1, (512, 2)) fake_data = generator(z, params) # Update discriminator grads = d_grad(params, real_data, fake_data) for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']: params[k] = params[k] - lr * grads[k] # Update generator fake_data = generator(z, params) grads = g_grad(params, fake_data) for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']: params[k] = params[k] - lr * grads[k] if step in [0, 500, 1500, 2999]: snapshots.append((step, fake_data.copy())) fig, axes = plt.subplots(1, 4, figsize=(16, 4)) for ax, (step, fake) in zip(axes, snapshots): ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='Real') ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='Generated') ax.set_title(f'Step {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2) ax.set_aspect('equal'); ax.legend(markerscale=3) plt.suptitle('GAN Training: Generator Learns the Ring Distribution') plt.tight_layout(); plt.show() -
实现 diffusion 前向过程:以增加的时间步长向图像添加噪声并可视化渐进的损坏。然后实施单个去噪步骤。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt def noise_schedule(T, beta_start=0.0001, beta_end=0.02): """Linear noise schedule.""" betas = jnp.linspace(beta_start, beta_end, T) alphas = 1.0 - betas alpha_bars = jnp.cumprod(alphas) return betas, alphas, alpha_bars def forward_diffusion(x0, t, alpha_bars, key): """Add noise to x0 at timestep t.""" alpha_bar_t = alpha_bars[t] noise = jax.random.normal(key, x0.shape) xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise return xt, noise # Create a simple 2D "image" (checkerboard) img = jnp.zeros((32, 32)) for i in range(4): for j in range(4): if (i + j) % 2 == 0: img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0) T = 1000 betas, alphas, alpha_bars = noise_schedule(T) # Visualise forward process timesteps = [0, 50, 200, 500, 999] key = jax.random.PRNGKey(42) fig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5)) for ax, t in zip(axes, timesteps): key, subkey = jax.random.split(key) xt, noise = forward_diffusion(img, t, alpha_bars, subkey) ax.imshow(xt, cmap='gray', vmin=-2, vmax=2) ax.set_title(f't={t}\n$\\bar{{\\alpha}}$={alpha_bars[t]:.3f}') ax.axis('off') plt.suptitle('Diffusion Forward Process: Progressive Noise Addition') plt.tight_layout(); plt.show() # Simple denoising: train a tiny network to predict noise at t=200 t_denoise = 200 key, k1 = jax.random.split(key) xt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1) # Tiny "denoiser": just learn a constant noise estimate (for illustration) noise_estimate = jnp.zeros_like(img) lr = 0.01 for step in range(100): residual = noise_estimate - true_noise noise_estimate = noise_estimate - lr * residual # Reverse one step alpha_bar_t = alpha_bars[t_denoise] x_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t) fig, axes = plt.subplots(1, 3, figsize=(12, 4)) axes[0].imshow(img, cmap='gray'); axes[0].set_title('Original $x_0$'); axes[0].axis('off') axes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2) axes[1].set_title(f'Noisy $x_{{200}}$'); axes[1].axis('off') axes[2].imshow(x_denoised, cmap='gray') axes[2].set_title('Denoised (one step)'); axes[2].axis('off') plt.tight_layout(); plt.show() mse = jnp.mean((x_denoised - img)**2) print(f"Denoising MSE: {mse:.4f}")