Skip to content

信息论

信息论对信息、惊讶程度以及概率分布之间的差异进行量化。本文涵盖 entropy、cross-entropy、KL divergence、mutual information 和 surprisal——这些概念是 ML 中每个分类损失函数、VAE 目标函数以及数据压缩方案的基础。

  • 信息论由 Claude Shannon 于 1948 年创立,为量化信息提供了数学框架。它回答了以下问题:对某个事件你应该有多惊讶?一条消息携带多少信息?两个概率分布有多大差异?

  • 这些问题听起来抽象,但它们是 ML 损失函数、数据压缩和通信系统的基础。分类中最常用的损失函数 cross-entropy 损失,直接来源于信息论。

  • 从最简单的问题开始:单个事件携带多少信息?

  • Surprisal(也称为自信息)衡量事件的惊讶程度。如果某件很可能发生的事情发生了,你几乎学不到什么。如果某件罕见的事情发生了,你会学到很多。

  • 如果你住在沙漠里,有人告诉你今天是晴天,这并不是很有信息量。如果他们告诉你在下雪,那就非常有信息量了。Surprisal 将这种直觉形式化:

\[I(x) = \log_2 \frac{1}{p(x)} = -\log_2 p(x)\]
  • 当使用 \(\log_2\) 时,单位是比特。公平硬币的 surprisal 为 \(-\log_2(0.5) = 1\) 比特。概率为 \(1/8\) 的事件的 surprisal 为 \(\log_2(8) = 3\) 比特。

  • 为什么用对数而不是 \(1/p\)?三个原因:

    • 必然事件(\(p = 1\))应该给出零信息:\(\log(1) = 0\),但 \(1/1 = 1\)
    • 独立事件的信息应该相加:\(\log(1/p_1 p_2) = \log(1/p_1) + \log(1/p_2)\)
    • 我们想要一个平滑、性质良好的函数。\(1/p\) 会爆炸;\(\log(1/p)\) 则平缓增长。
  • Entropy 是期望 surprisal,即从分布中采样每个事件时平均获得的信息量。它衡量分布的不确定性或"不可预测性":

\[H(X) = E[I(X)] = -\sum_{x} p(x) \log_2 p(x)\]

条形图展示高概率事件的 surprisal 低,反之亦然;entropy 是加权平均

  • 公平硬币的 entropy 为 \(H = -0.5\log_2(0.5) - 0.5\log_2(0.5) = 1\) 比特,不确定性最大。

  • 偏置为 \(p = 0.9\) 的硬币的 entropy 为 \(H = -0.9\log_2(0.9) - 0.1\log_2(0.1) \approx 0.469\) 比特。不确定性更小,entropy 也更小。

  • 确定性事件(\(p = 1\))的 entropy 为 \(H = 0\),没有任何不确定性。

  • 当所有结果等可能时,entropy 最大。对于 \(n\) 个等可能结果,\(H = \log_2 n\)。公平骰子的 entropy 为 \(\log_2 6 \approx 2.585\) 比特。

  • Entropy 的实际意义是压缩。Shannon 的信源编码定理指出,在不丢失信息的前提下,无法将数据压缩到其 entropy 率以下。每个像素都等可能出现的图像(最大 entropy)无法被压缩;大部分为白色的图像(低 entropy)压缩效果很好。

  • 量级感受:灰度像素(256 个值)的最大 entropy 为 8 比特。一张 1080p 灰度图像最多有 \(1920 \times 1080 \times 8 \approx 1660\) 万比特。真实图像的 entropy 要低得多,因为相邻像素是相关的,这就是 JPEG 压缩有效的原因。

  • 对于连续随机变量,离散求和变为积分。微分 entropy 为:

\[h(X) = -\int_{-\infty}^{\infty} f(x) \log f(x)\, dx\]
  • 方差为 \(\sigma^2\) 的 Gaussian 的微分 entropy 为 \(h = \frac{1}{2}\log_2(2\pi e \sigma^2)\)。在所有具有相同方差的分布中,Gaussian 具有最大 entropy。这是 Gaussian 在建模中如此常见的原因之一:它在指定的均值和方差之外做出最少的假设。

  • Mutual information 衡量知道一个变量能告诉你多少关于另一个变量的信息。它是观察 \(Y\)\(X\) 的不确定性的减少量:

\[I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)\]
  • 等价地:
\[I(X; Y) = \sum_{x,y} p(x,y) \log_2 \frac{p(x,y)}{p(x) p(y)}\]
  • 如果 \(X\)\(Y\) 独立,\(p(x,y) = p(x)p(y)\),mutual information 为零。它们越依赖,mutual information 越高。

  • 在 ML 中,mutual information 用于特征选择(选择与目标具有高 mutual information 的特征)、information bottleneck 方法以及评估聚类质量。

  • Cross-entropy 衡量使用针对分布 \(q\) 优化的编码来编码来自分布 \(p\) 的事件所需的平均比特数:

\[H(p, q) = -\sum_{x} p(x) \log_2 q(x)\]
  • 如果 \(q\)\(p\) 完全匹配,cross-entropy 等于 entropy:\(H(p, p) = H(p)\)。如果 \(q\) 是一个差劲的近似,cross-entropy 会更高。"额外"的比特来自不匹配。

  • 这正是 cross-entropy 成为 ML 分类标准损失函数的原因。真实标签定义 \(p\)(one-hot 分布),模型预测的概率定义 \(q\)。最小化 cross-entropy 将 \(q\) 推向 \(p\)

\[\mathcal{L} = -\sum_{c} y_c \log \hat{y}_c\]
  • 对于真实类别为 \(c\) 的单个样本,这简化为 \(\mathcal{L} = -\log \hat{y}_c\)。损失是模型预测下真实类别的 surprisal。如果模型对正确类别赋予高概率,损失就低。

  • KL divergence(Kullback-Leibler 散度,也称为相对 entropy)衡量一个分布与另一个分布的差异程度:

\[D_{\text{KL}}(p \| q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} = H(p, q) - H(p)\]
  • KL divergence 是使用分布 \(q\) 代替真实分布 \(p\) 的"额外代价"。它始终是非负的(\(D_{\text{KL}} \ge 0\)),只有当 \(p = q\) 时才等于零。

两个分布 p 和 q,它们之间的差距表示 KL divergence

  • KL divergence 是不对称的:\(D_{\text{KL}}(p \| q) \ne D_{\text{KL}}(q \| p)\)。这种不对称性很重要。\(D_{\text{KL}}(p \| q)\) 惩罚 \(q\)\(p\) 概率高的地方赋予低概率(因为 \(\log(p/q)\) 会变得很大)。\(D_{\text{KL}}(q \| p)\) 则惩罚相反的情况。

  • 这种不对称性导致两种近似风格:

    • 最小化 \(D_{\text{KL}}(p \| q)\) 产生矩匹配行为:\(q\) 覆盖 \(p\) 的所有众数,但可能过于分散。
    • 最小化 \(D_{\text{KL}}(q \| p)\) 产生众数寻找行为:\(q\) 集中于 \(p\) 的一个众数,但可能错过其他众数。这是变分推断所使用的方式。
  • 由于 \(H(p)\) 相对于模型是常数,最小化 cross-entropy \(H(p, q)\) 等价于最小化 \(D_{\text{KL}}(p \| q)\)。这就是为什么我们可以使用 cross-entropy 损失,同时知道我们也在最小化真实分布和预测分布之间的 KL divergence。

  • KL divergence 在贝叶斯更新中起着核心作用。Posterior \(P(\theta | D)\) 是在 KL divergence 意义上最接近 prior \(P(\theta)\) 且与观察数据一致的分布。每次新的观察都会更新 posterior,减少对 \(\theta\) 的不确定性。

  • 在变分自编码器(VAE)中,损失函数有两项:重建损失(cross-entropy)和 KL divergence 项,后者将隐空间正则化,使其保持接近标准正态分布。

  • 综合来看:entropy 告诉你分布中固有的不确定性,cross-entropy 告诉你模型对现实的近似程度,KL divergence 告诉你两者之间的差距。这三个量构成了现代 ML 优化的核心。

编程任务(使用 CoLab 或 notebook)

  1. 计算各种分布的 entropy,并验证均匀分布对于给定结果数具有最大 entropy。

    import jax.numpy as jnp
    
    def entropy(p):
        """计算以比特为单位的 entropy,过滤掉零概率事件。"""
        p = p[p > 0]
        return -jnp.sum(p * jnp.log2(p))
    
    # 公平骰子
    fair = jnp.ones(6) / 6
    print(f"公平骰子 entropy:   {entropy(fair):.4f} 比特(最大值 = log2(6) = {jnp.log2(6.):.4f})")
    
    # 有偏骰子
    loaded = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.5])
    print(f"有偏骰子 entropy: {entropy(loaded):.4f} 比特")
    
    # 确定性
    det = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
    print(f"确定性分布:      {entropy(det):.4f} 比特")
    
    # 公平硬币
    coin = jnp.array([0.5, 0.5])
    print(f"公平硬币 entropy:  {entropy(coin):.4f} 比特")
    

  2. 计算真实分布与若干近似分布之间的 cross-entropy 和 KL divergence。验证 \(D_{\text{KL}}(p \| q) = H(p, q) - H(p)\)

    import jax.numpy as jnp
    
    def cross_entropy(p, q):
        return -jnp.sum(p * jnp.log2(jnp.clip(q, 1e-10, 1.0)))
    
    def kl_divergence(p, q):
        mask = p > 0
        return jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0))
    
    def entropy(p):
        p = p[p > 0]
        return -jnp.sum(p * jnp.log2(p))
    
    p = jnp.array([0.4, 0.3, 0.2, 0.1])  # 真实分布
    
    for name, q in [("完全匹配", p),
                    ("轻微不匹配", jnp.array([0.35, 0.30, 0.25, 0.10])),
                    ("严重不匹配", jnp.array([0.1, 0.1, 0.1, 0.7]))]:
        h_p = entropy(p)
        h_pq = cross_entropy(p, q)
        kl = kl_divergence(p, q)
        print(f"{name:12s}: H(p)={h_p:.4f}, H(p,q)={h_pq:.4f}, "
              f"KL={kl:.4f}, H(p,q)-H(p)={h_pq-h_p:.4f}")
    

  3. 对两个不同的分布计算 \(D_{\text{KL}}(p \| q)\)\(D_{\text{KL}}(q \| p)\),展示 KL divergence 的不对称性。

    import jax.numpy as jnp
    
    def kl_div(p, q):
        mask = p > 0
        return float(jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0)))
    
    p = jnp.array([0.9, 0.1])
    q = jnp.array([0.5, 0.5])
    
    print(f"D_KL(p || q) = {kl_div(p, q):.4f}")
    print(f"D_KL(q || p) = {kl_div(q, p):.4f}")
    print(f"结果不同!KL divergence 是不对称的。")
    

  4. 模拟训练过程中的 cross-entropy 损失。创建一个"真实"的 one-hot 标签,展示随着模型预测概率提升损失如何下降。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 真实标签:4 类中的第 2 类
    true_label = jnp.array([0, 0, 1, 0])
    
    # 模拟预测改进过程
    steps = []
    losses = []
    for confidence in jnp.linspace(0.25, 0.99, 50):
        # 模型对第 2 类越来越有把握
        remaining = (1 - confidence) / 3
        pred = jnp.array([remaining, remaining, confidence, remaining])
        loss = -jnp.sum(true_label * jnp.log(jnp.clip(pred, 1e-10, 1.0)))
        steps.append(float(confidence))
        losses.append(float(loss))
    
    plt.figure(figsize=(8, 4))
    plt.plot(steps, losses, color="#e74c3c", linewidth=2)
    plt.xlabel("模型对真实类别的置信度")
    plt.ylabel("Cross-entropy 损失")
    plt.title("随着预测改善,cross-entropy 损失下降")
    plt.grid(alpha=0.3)
    plt.show()