Skip to content

Statistical Inference(统计推断)

统计推断超越了是/否的决策,通过量化的不确定性来估计总体参数。本节涵盖置信区间、点估计和区间估计、最大似然估计、矩估计以及回归分析,它们是 ML 中原始数据和预测模型之间的桥梁。

  • 假设检验给出一个是/否的决定:拒绝或不拒绝。但通常你需要更具信息量的内容,即你正在估计的参数的一个合理取值范围。这就是置信区间(confidence intervals)所提供的。

  • 点估计(point estimate)是由样本计算出的单一数字,如样本均值 \(\bar{x}\)。它是你对总体参数的最佳猜测,但它本身无法体现估计的精确程度。

  • 置信区间(confidence interval)用一个反映不确定性的范围将该点估计包裹起来。其形式为:

\[\text{CI} = \bar{x} \pm \text{ME}\]
  • 误差幅度(margin of error,ME)取决于三件事:你期望的置信度、数据的变异程度以及样本量的大小:
\[\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}\]
  • 这里的 \(z^\ast\) 是来自正态分布的临界值,与你期望的置信水平相匹配。对于 95% 的置信度,\(z^\ast = 1.96\)。对于 99% 的置信度,\(z^\ast = 2.576\)

置信区间:带有两侧误差幅度的点估计

  • 95% 置信区间(95% confidence interval)意味着:如果你多次重复该实验并每次都构建一个区间,那么大约 95% 的区间将包含真实的总体参数。这并不意味着参数在该特定区间内的概率为 95%。参数是固定的;变化的是区间。

  • 例题:你测量了 50 个人的身高,发现 \(\bar{x} = 170\) cm,\(\sigma = 8\) cm。构建一个 95% 置信区间。

\[\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}\]
\[\text{CI} = [170 - 2.22, \; 170 + 2.22] = [167.78, \; 172.22]\]
  • 你可以有 95% 的把握说真实平均身高介于 167.78 和 172.22 厘米之间。

  • \(\sigma\) 未知(通常情况)时,改用样本标准差 \(s\) 和 t 分布:

\[\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}\]
  • 更宽的区间置信度更高但精确度较低。更窄的区间精确度更高但置信度较低。你可以通过增加样本量在不降低置信度的情况下缩小区间。

  • 功效分析(power analysis)帮助你在进行实验之前进行规划。问题在于:我需要多大的样本量才能以指定的功效检测出给定大小的效应?

  • 回顾上一节,功效(power) = \(1 - \beta\),即正确拒绝虚假 \(H_0\) 的概率。一个常见的目标是 80% 的功效。

  • 进行 z 检验,以显著性 \(\alpha\) 和功效 \(1-\beta\) 检测出差异 \(\delta\),所需的样本量为:

\[n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2\]
  • 例如,要在 \(\alpha = 0.05\) 和 80% 功效(\(z_{0.025} = 1.96\)\(z_{0.20} = 0.84\))下检测出平均身高(\(\sigma = 8\))2 厘米的差异:
\[n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 \approx 126\]
  • 你每组大约需要 126 人。

  • 功效分析可以防止两个常见错误:实验规模太小无法检测出真实效应(功效不足,underpowered),或者将资源浪费在远大于必要规模的实验上(功效过高,overpowered)。

  • 蒙特卡洛方法(Monte Carlo methods)使用随机抽样来解决难以或无法通过解析方法解决的问题。核心思想是:如果你无法精确计算某事,就对其进行多次模拟,并使用结果作为近似值。

  • 这个名字来源于蒙特卡洛赌场,以致敬随机性在其中的作用。这些方法是 ML 中用于估计积分、评估模型不确定性和近似复杂分布等任务的主力。

  • 蒙特卡洛的一般步骤:

    • 定义可能输入的定义域
    • 从该定义域中生成随机输入
    • 对每个输入求函数值
    • 聚合结果(平均、计数等)
  • 一个经典的例子是估计 \(\pi\)。想象一个边长为 2、以原点为中心的正方形,内部内接一个半径为 1 的圆。正方形的面积为 4,圆的面积为 \(\pi\)

带有内接圆的正方形,随机点按在内部/外部着色

  • 在正方形内均匀地随机投点。落在圆内点的比例近似于 \(\pi/4\)
\[\pi \approx 4 \times \frac{\text{圆内点数}}{\text{总点数}}\]
  • 如果 \(x^2 + y^2 \le 1\),则点 \((x, y)\) 在圆内。投掷的点越多,你的估计值就越接近 \(\pi\) 的真实值。

  • 在 ML 中,蒙特卡洛方法出现在:

    • 蒙特卡洛 dropout(Monte Carlo dropout):在启用 dropout 的情况下多次运行推理,以估计预测的不确定性
    • 马尔可夫链蒙特卡洛(MCMC,Markov Chain Monte Carlo):在贝叶斯模型中从复杂的后验分布中抽样
    • 策略梯度方法(Policy gradient methods):在强化学习中通过对轨迹抽样来估计梯度
  • 因子分析(Factor analysis)是一种发现隐藏(潜在)变量的技术,这些变量解释了观测变量之间的相关性。如果 10 个性格调查问题可以由 3 个潜在特质(外向性、宜人性、尽责性)来解释,因子分析就能找出这些特质。

  • 该模型假设每个观测变量 \(x_i\) 是几个潜在因子 \(f_j\) 的线性组合加上噪声:

\[x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i\]
  • \(\lambda\) 值被称为因子载荷(factor loadings),并告诉你每个观测变量与每个因子的关联强度。这与第 2 章中的矩阵分解直接相关;因子分析与特征值分解和 SVD 密切相关。

  • 实验设计(Experimental design)是一门构建实验的艺术,以便你能够得出有效的结论。糟糕的设计甚至会使庞大的数据集变得毫无用处。

  • 精心设计的实验的关键组成部分:

    • 自变量(Independent variable,IV):你操作的内容(例如药物剂量、模型架构)
    • 因变量(Dependent variable,DV):你测量的内容(例如恢复时间、准确率)
    • 对照组(Control group):不接受治疗(或接受安慰剂),为比较提供基线
    • 随机分配(Random assignment):参与者被随机分配到各组,这抵消了你未测量的混杂变量的影响
  • 常见的实验设计

    • 完全随机设计(Completely randomised design):受试者被随机分配到处理组中。当各组具有可比性时,简单有效。
    • 随机区组设计(Randomised block design):受试者首先被分组到区组中(例如按年龄),然后在每个区组内随机分配处理。这减少了区组因素带来的变异性,精神上类似于分层抽样。
    • 析因设计(Factorial design):同时测试多个自变量。一个 \(2 \times 3\) 的析因设计有一个变量的 2 个水平和另一个变量的 3 个水平,共有 6 种处理组合。这让你能检测到交互作用(interactions),即一个变量的效应取决于另一个变量的水平。
    • 交叉设计(Crossover design):每个受试者依次接受所有处理(中间有洗脱期)。每个受试者作为自己的对照,减少了个体差异的影响。
  • 在 ML 实验中,这些原则至关重要。在比较模型时,你应该控制随机种子、数据集划分和硬件。交叉验证是交叉设计的一种形式。每次移除一个组件的消融研究(Ablation studies)遵循析因设计的逻辑。

Coding Tasks(编程练习,使用 CoLab 或 notebook)

  1. 为身高示例构建一个 95% 的置信区间,然后尝试不同的置信水平和样本量。

    import jax.numpy as jnp
    
    x_bar = 170.0    # 样本均值
    sigma = 8.0      # 总体标准差(已知)
    n = 50           # 样本量
    
    # 常见置信水平的临界值
    z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
    
    for conf, z_star in z_stars.items():
        me = z_star * (sigma / jnp.sqrt(n))
        lower, upper = x_bar - me, x_bar + me
        print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}]  (误差幅度 ME = {me:.2f})")
    

  2. 使用蒙特卡洛模拟估计 \(\pi\)。绘制随着点数增加,估计值如何收敛的图。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(42)
    
    # 在 [-1, 1] x [-1, 1] 生成随机点
    n_points = 100_000
    k1, k2 = jax.random.split(key)
    x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
    y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)
    
    # 检查哪些点在单位圆内
    inside = (x**2 + y**2) <= 1.0
    cumulative_inside = jnp.cumsum(inside)
    counts = jnp.arange(1, n_points + 1)
    pi_estimates = 4.0 * cumulative_inside / counts
    
    plt.figure(figsize=(10, 4))
    plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)
    plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
    plt.xlabel("点数")
    plt.ylabel("π 的估计值")
    plt.title("蒙特卡洛方法估计 π")
    plt.legend()
    plt.ylim(2.8, 3.5)
    plt.show()
    
    print(f"最终估计值: {pi_estimates[-1]:.6f}")
    print(f"真实值:     {jnp.pi:.6f}")
    print(f"误差:          {abs(pi_estimates[-1] - jnp.pi):.6f}")
    

  3. 执行一个简单的功效分析:给定效应量和标准差,计算所需的样本量并通过模拟验证。

    import jax
    import jax.numpy as jnp
    
    # 参数
    delta = 2.0      # 效应量(均值差异)
    sigma = 8.0      # 总体标准差
    alpha = 0.05
    power_target = 0.80
    
    # 解析法计算样本量
    z_alpha = 1.96   # 双尾,alpha=0.05
    z_beta = 0.84    # 功效 power=0.80
    n_required = ((z_alpha + z_beta) * sigma / delta) ** 2
    print(f"每组所需样本量 n: {n_required:.0f}")
    
    # 通过模拟验证
    key = jax.random.PRNGKey(7)
    n = int(jnp.ceil(n_required))
    n_sims = 5000
    rejections = 0
    
    for _ in range(n_sims):
        key, k1, k2 = jax.random.split(key, 3)
        group_a = jax.random.normal(k1, shape=(n,)) * sigma + 50
        group_b = jax.random.normal(k2, shape=(n,)) * sigma + 50 + delta
        pooled_se = jnp.sqrt(2 * sigma**2 / n)
        z = (group_b.mean() - group_a.mean()) / pooled_se
        p = 2 * (1 - __import__("jax").scipy.stats.norm.cdf(jnp.abs(z)))
        if p <= alpha:
            rejections += 1
    
    print(f"模拟功效: {rejections/n_sims:.3f}")
    print(f"目标功效:    {power_target:.3f}")
    

  4. 可视化置信区间宽度如何随样本量变化。这展示了为什么收集更多数据能提供更精确的估计。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    sigma = 8.0
    z_star = 1.96  # 95% 置信度
    
    sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
    margins = z_star * sigma / jnp.sqrt(sample_sizes)
    
    plt.figure(figsize=(8, 4))
    plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
    plt.xlabel("样本量")
    plt.ylabel("误差幅度 (cm)")
    plt.title("95% 置信区间的误差幅度随样本量增大而缩小")
    plt.show()