Sampling(抽样)¶
抽样决定了我们如何收集数据,并直接控制着我们所得出每个结论的质量。本节涵盖随机抽样、分层抽样、整群抽样和系统抽样,以及抽样分布、大数定律和自助法(bootstrapping)——这些方法对 ML 中的训练/测试集划分和数据集整理至关重要。
-
在理想情况下,你会测量所关注群体中的每一个成员。但在实践中,这几乎是不可能的。你无法调查每一位选民、测试每一个灯泡,或扫描每一位患者。因此,你需要采集一个样本(sample),并用它来了解整体。
-
总体(population)是你想研究的完整个体或项目集合。样本(sample)是你实际观测的子集。
-
参数(parameter)是描述总体的数字(例如一个国家所有成年人的真实平均身高)。
-
统计量(statistic)是从样本中计算出的数字(例如你测量的 500 人的平均身高)。统计量用于估计参数。
-
你的结论质量完全取决于你如何选择样本。有偏的样本会导致有偏的结论,无论你的分析有多复杂。
-
抽样框(sampling frame)是你实际从中抽取样本的所有个体的列表。理想情况下,它与总体完全吻合,但在实践中往往存在差距。
-
例如,如果你通过电话调查,你会错过所有没有电话的人。抽样框与总体之间的差异称为覆盖误差(coverage error)。
-
抽样误差(sampling error)是样本统计量与总体参数之间的自然差异。
-
即使是完全随机的样本也不会与总体完全吻合。更大的样本会减少抽样误差。
-
抽样方法分为两大类:概率抽样和非概率抽样。
-
概率抽样(probability sampling)是指总体中每个成员都有已知的、非零的被选中概率。这使你能够量化不确定性并推广结论。
-
简单随机抽样(simple random sampling):每个个体被选中的概率相等,每个大小为 \(n\) 的可能样本被选中的概率也相等。想象把每个名字放入一顶帽子并盲目抽取。
-
分层抽样(stratified sampling):根据共同特征(例如年龄组、地区)将总体划分为不重叠的组(层),然后从每一层中随机抽样。这保证了每个组的代表性,并在各层之间差异较大时降低方差。
-
整群抽样(cluster sampling):将总体划分为若干组(群),随机选择某些群,然后纳入所选群中的所有个体。当总体在地理上分散时,这种方法很实用,例如在一个学区中抽取整所学校而不是单个学生。
-
系统抽样(systematic sampling):选择一个随机起点,然后从列表中每隔 \(k\) 个个体选取一个。例如,从第 7 个人开始,每隔 10 个人取一个(7, 17, 27, ...)。易于实施,但如果列表存在隐藏规律,可能会引入偏差。
-
非概率抽样(non-probability sampling)不能给每个成员一个已知的被选中概率。结果不能被严格推广,但这些方法通常更快、更便宜。
-
便利抽样(convenience sampling):选择最容易接触到的人。在购物中心调查人们很方便,但会遗漏那些不去购物中心的人。
-
配额抽样(quota sampling):类似于分层抽样,但没有随机性。研究者通过从每组中选取容易接触的个体来填满配额(例如 50 名男性和 50 名女性)。
-
滚雪球抽样(snowball sampling):从少数参与者开始,请他们招募其他人。适用于难以接触的群体(例如研究罕见疾病),但严重偏向于人际关系密切的个体。
-
确定抽样方法后,一个自然的问题出现了:如果我抽取不同的样本,会得到不同的统计量吗?几乎肯定会。抽样分布(sampling distribution)是某个统计量(如样本均值)在所有相同大小样本中的分布。
-
想象抽取 1,000 个不同的 30 人样本,并计算每个样本的平均身高。这 1,000 个均值构成一个分布。有些会略高于真实总体均值,有些略低,大多数会聚集在真实值周围。
-
这个抽样分布的标准差称为标准误差(standard error):
-
注意,随着 \(n\) 增大,标准误差缩小。更大的样本给出更精确的估计。样本量翻四倍,标准误差减半。
-
统计学中最重要的结果是中心极限定理(Central Limit Theorem,CLT)。它指出:无论原始总体的形状如何,随着样本量的增加,样本均值的分布趋近于正态分布。
- 更精确地说,如果 \(X_1, X_2, \ldots, X_n\) 是来自任意均值为 \(\mu\)、有限方差为 \(\sigma^2\) 的分布的独立观测值,那么随着 \(n\) 增大:
-
CLT 使大多数推断统计得以运作。它允许我们使用正态分布作为近似,即使底层数据不是正态分布的,只要样本足够大。
-
多大才算"足够大"?一个常见的经验法则是 \(n \ge 30\),但这取决于总体的非正态程度。对于高度偏斜的分布,你可能需要更多。对于大致对称的总体,即使 \(n = 10\) 也可能足够。
-
CLT 有三个关键条件:
- 独立性(independence):每个观测值不得影响其他观测值
- 有限方差(finite variance):总体方差必须存在(排除某些奇异分布)
- 同分布(identical distribution):所有观测值来自同一分布
Coding Tasks(编程练习,使用 CoLab 或 notebook)¶
-
直观演示 CLT:从高度偏斜的分布中抽取样本,计算样本均值,观察均值直方图如何变为钟形。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt key = jax.random.PRNGKey(0) # 指数分布(非常偏斜) population = jax.random.exponential(key, shape=(100_000,)) fig, axes = plt.subplots(1, 4, figsize=(14, 3)) sample_sizes = [1, 5, 30, 100] for ax, n in zip(axes, sample_sizes): keys = jax.random.split(key, 2000) means = jnp.array([jax.random.choice(k, population, shape=(n,)).mean() for k in keys]) ax.hist(means, bins=40, color="#3498db", alpha=0.7, density=True) ax.set_title(f"n = {n}") ax.set_xlim(0, 4) fig.suptitle("CLT:随着 n 增大,样本均值趋近正态分布", fontsize=13) plt.tight_layout() plt.show() -
比较简单随机抽样与分层抽样。创建一个具有不同组的总体,并展示分层抽样在估计中给出更低的方差。
import jax import jax.numpy as jnp key = jax.random.PRNGKey(42) # 总体:两个不同的组 group_a = jax.random.normal(key, shape=(500,)) + 10 # 均值约为 10 key, subkey = jax.random.split(key) group_b = jax.random.normal(subkey, shape=(500,)) + 20 # 均值约为 20 population = jnp.concatenate([group_a, group_b]) # 简单随机抽样:1000 次试验,样本量 20 srs_means = [] for i in range(1000): key, subkey = jax.random.split(key) sample = jax.random.choice(subkey, population, shape=(20,), replace=False) srs_means.append(sample.mean()) srs_means = jnp.array(srs_means) # 分层抽样:每组各取 10 个 strat_means = [] for i in range(1000): key, k1, k2 = jax.random.split(key, 3) s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False) s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False) strat_means.append(jnp.concatenate([s_a, s_b]).mean()) strat_means = jnp.array(strat_means) print(f"简单随机 - 均值: {srs_means.mean():.3f}, 标准差: {srs_means.std():.3f}") print(f"分层抽样 - 均值: {strat_means.mean():.3f}, 标准差: {strat_means.std():.3f}") print(f"分层抽样将方差降低了 {(1 - strat_means.var()/srs_means.var())*100:.1f}%") -
探讨样本量如何影响标准误差。绘制标准误差与样本量的关系图,并验证 \(1/\sqrt{n}\) 关系。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt key = jax.random.PRNGKey(7) population = jax.random.normal(key, shape=(50_000,)) * 10 + 50 sample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000] std_errors = [] for n in sample_sizes: means = [] for _ in range(500): key, subkey = jax.random.split(key) sample = jax.random.choice(subkey, population, shape=(n,)) means.append(sample.mean()) std_errors.append(jnp.array(means).std()) plt.figure(figsize=(8, 4)) plt.plot(sample_sizes, std_errors, "o-", color="#e74c3c", label="观测 SE") theoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32)) plt.plot(sample_sizes, theoretical, "--", color="#3498db", label="σ/√n(理论值)") plt.xlabel("样本量 (n)") plt.ylabel("标准误差") plt.legend() plt.title("标准误差随样本量增大而缩小") plt.show()