Differential Calculus(微分学)¶
微分学捕捉瞬时变化率。本节涵盖极限、derivative、微分规则、链式法则(反向传播的基础),以及 ML 中常用的 derivative——这些是整个 ML 的核心工具。
-
在前几章中,我们学习了如何将数据表示为 vector 并用 matrix 进行变换。但许多现实世界的现象并非静态的。汽车加速、股价波动、神经网络的 loss 随权重更新而变化。Calculus(微积分)是研究变化的数学。
-
Calculus 回答两个问题:某事物现在变化有多快?(微分学)以及它随时间积累了多少?(积分学)本节讨论"变化有多快"这个问题。
-
想象你在开车,瞥了一眼速度表,显示 60 km/h。这个数字不是你整个旅程的平均速度,而是你此刻的速度。微分学给了我们计算这种瞬时变化率的工具。
-
但首先,让我们回顾一下直线方程:\(y = mx + b\)。
-
这是两个量之间最简单的关系。
- \(b\) 是 y 截距,即直线与 y 轴的交点(当 \(x = 0\) 时的起始值)。
- \(m\) 是斜率,即变化率:\(x\) 每增加 1 个单位,\(y\) 就变化 \(m\)。
- 若 \(m = 3\),直线急剧上升;若 \(m = 0\),直线是水平的;若 \(m = -2\),直线下降。
-
斜率计算为 \(m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}\),即"\(y\) 变化了多少"与"\(x\) 变化了多少"之比。
-
一旦知道 \(m\) 和 \(b\),就可以对任意 \(x\) 计算 \(y\)。
-
例如,若 \(m = 2\),\(b = 3\),则当 \(x = 5\) 时:\(y = 2(5) + 3 = 13\)。
-
这两个参数完全确定了直线,预测任意输出只需代入即可。
-
对于直线,斜率处处相同。
-
这个思想可以推广到直线以外。任何函数都是将输入映射到输出的规则,一旦知道其公式(参数和形状),就能计算任意输入的输出并绘制结果。
-
\(y = x^2\) 给出抛物线,\(y = \sin(x)\) 给出波形,\(y = e^x\) 给出指数增长。每个公式定义一条特定的曲线,能够通过形状来读懂函数是后续一切内容的基础。
-
对于直线,斜率处处相同。但大多数有趣的函数是曲线,斜率因点而异。Calculus 给了我们一种方法,能在曲线上的任意单点处求斜率。
-
我们还需要极限的概念。极限描述的是:当函数的输入越来越接近某个目标时,函数值趋向于什么,而不必真的达到该目标。
-
这读作:"当 \(x\) 趋近 \(a\) 时,\(f(x)\) 趋近 \(L\)。"函数不需要在 \(x = a\) 处等于 \(L\),只需任意接近即可。
-
例如,取 \(f(x) = \frac{x^2 - 1}{x - 1}\)。直接代入 \(x = 1\) 会得到 \(\frac{0}{0}\),这是未定义的。
-
但试试接近 1 的值:\(f(0.9) = 1.9\),\(f(0.99) = 1.99\),\(f(1.01) = 2.01\)。输出显然趋向 2。
-
从代数上可以看出原因:将分子分解为 \((x-1)(x+1)\),约去 \((x-1)\),对所有 \(x \neq 1\) 得 \(f(x) = x + 1\)。因此当 \(x \to 1\) 时,\(f(x) \to 2\)。
-
该函数在 \(x = 1\) 处有一个"洞",但极限仍然存在。
-
极限是 calculus 中一切其他内容的基础。
-
函数 \(f(x)\) 在点 \(x = a\) 处的 derivative 度量的是瞬时变化率。从几何上看,它是曲线在该点处切线的斜率。
- 为了计算这个斜率,我们从曲线上取两个点,计算过这两点的直线(割线)的斜率,然后让第二个点越来越靠近第一个点,观察割线斜率趋近于什么值。这就是差商:
-
分子 \(f(a+h) - f(a)\) 是输出的变化量。分母 \(h\) 是输入的变化量。它们的比值是在微小区间上的平均变化率。当 \(h \to 0\) 时,这个平均值变为瞬时变化率。
-
例如,令 \(f(x) = x^2\),在 \(x = 3\) 处:
-
所以在 \(x = 3\) 处,函数 \(x^2\) 的变化率是每单位输入变化 6 个单位输出。
-
若函数在某点处该极限存在,则称其在该点可微。为此,函数必须连续(无跳变)、光滑(无尖角),且在该点邻域内有定义。
-
如果你能不抬笔、不转折地画出该曲线,它在那里很可能是可微的。
-
每次都从极限定义来计算 derivative 会很繁琐。幸运的是,几条规则让我们能快速对几乎任何函数求导。
-
常数规则:常数的 derivative 为零。若 \(f(x) = 5\),则 \(f'(x) = 0\)。水平线的斜率为零。
-
幂次规则:微分的主力。将指数移到前面,指数减一:
-
例如:\(\frac{d}{dx} x^3 = 3x^2\)。三次变为二次。这对任意实数指数均适用,包括负数和分数:\(\frac{d}{dx} x^{-1} = -x^{-2}\),\(\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}\)。
-
和差规则:逐项求导。
- 乘积规则:两个函数相乘时,derivative 并非简单地是各 derivative 之积。而是:
-
可以理解为:"第一个的变化率乘以第二个,加上第一个乘以第二个的变化率。"例如,\(\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x\)。
-
商法则:对函数之比:
-
一个有用的记忆口诀:"分母乘分子的导数,减分子乘分母的导数,除以分母的平方。"
-
链式法则:ML 中最重要的规则。当函数复合(一个套在另一个里面)时,derivative 是沿链各 derivative 的乘积:
- 把它想象成剥洋葱。先对外层函数求导(保持内层函数不变),再乘以内层函数的 derivative。
-
例如,\(\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4\)。外层函数是 \((\cdot)^5\),内层函数是 \(3x+1\)。
-
链式法则是神经网络反向传播的数学基础。深层网络是一长串复合函数。为了计算 loss 关于每个权重的变化,我们从输出 layer 到输入逐层反向应用链式法则,在每一步将局部 derivative 相乘。
-
以下是你会遇到的最常见的 derivative。每一个都可以从极限定义推导出来,但熟记它们能节省时间:
| 函数 | Derivative | 备注 |
|---|---|---|
| \(e^x\) | \(e^x\) | 唯一等于自身 derivative 的函数 |
| \(a^x\) | \(a^x \ln a\) | 推广了指数 |
| \(\ln x\) | \(\frac{1}{x}\) | 自然对数 |
| \(\log_a x\) | \(\frac{1}{x \ln a}\) | 一般对数 |
| \(\sin x\) | \(\cos x\) | |
| \(\cos x\) | \(-\sin x\) | 注意负号 |
| \(\tan x\) | \(\sec^2 x\) |
-
指数函数 \(e^x\) 很特别:它是唯一等于自身 derivative 的函数。这就是为什么 \(e\) 在 ML 中无处不在,从 softmax activation 到概率分布。
-
洛必达法则(L'Hopital's Rule)处理产生 \(\frac{0}{0}\) 或 \(\frac{\infty}{\infty}\) 等不定形式的极限。当直接代入给出这些不定形式时,可以分别对分子和分母求导,然后再求极限:
-
条件:\(f\) 和 \(g\) 在 \(a\) 附近必须可微,且 \(g'(x) \neq 0\)(\(a\) 点除外)。原极限必须产生不定形式。
-
例如:\(\lim_{x \to 0} \frac{\sin x}{x}\)。直接代入给出 \(\frac{0}{0}\)。应用洛必达法则:\(\lim_{x \to 0} \frac{\cos x}{1} = 1\)。这个极限很基本——它出现在信号处理和 Fourier 分析中。
-
若结果仍为不定形式,可以反复应用该法则。例如,\(\lim_{x \to 0} \frac{1 - \cos x}{x^2}\) 给出 \(\frac{0}{0}\)。第一次应用:\(\lim_{x \to 0} \frac{\sin x}{2x}\),仍为 \(\frac{0}{0}\)。第二次应用:\(\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}\)。
-
如果两个函数可微,它们的和、差、积、复合以及商(分母非零时)也都是可微的。这就是为什么我们能有把握地对由简单部分构成的复杂表达式求导。
编程练习(使用 CoLab 或 notebook)¶
-
可视化常见函数。并排绘制 \(x^2\)、\(\sin(x)\) 和 \(e^x\),建立对不同公式如何产生不同形状的直觉。尝试修改参数(如 \(2x^2\)、\(\sin(2x)\)),观察曲线如何变化。
import jax.numpy as jnp import matplotlib.pyplot as plt x = jnp.linspace(-3, 3, 300) fig, axes = plt.subplots(1, 3, figsize=(12, 3)) axes[0].plot(x, x**2, color="#e74c3c") axes[0].set_title("x²(抛物线)") axes[1].plot(x, jnp.sin(x), color="#3498db") axes[1].set_title("sin(x)(波形)") axes[2].plot(x, jnp.exp(x), color="#27ae60") axes[2].set_title("eˣ(指数)") for ax in axes: ax.axhline(0, color="gray", linewidth=0.5) ax.axvline(0, color="gray", linewidth=0.5) plt.tight_layout() plt.show() -
使用 JAX 的自动微分计算 \(f(x) = x^3 - 2x + 1\) 在多个点处的 derivative。与解析 derivative \(f'(x) = 3x^2 - 2\) 进行比较。
-
数值验证链式法则。定义 \(f(x) = \sin(x^2)\),通过
jax.grad计算其 derivative,并与解析结果 \(2x\cos(x^2)\) 比较。 -
可视化 derivative。在同一图上绘制 \(f(x) = x^3 - 3x\) 及其 derivative \(f'(x) = 3x^2 - 3\)。注意 \(f'(x) = 0\) 的位置对应 \(f\) 的极大值和极小值。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt f = lambda x: x**3 - 3*x # jax.grad 作用于 scalar;jax.vmap 将其向量化,以便对输入数组逐元素运算 df = jax.vmap(jax.grad(f)) x = jnp.linspace(-2.5, 2.5, 200) plt.plot(x, jax.vmap(f)(x), label="f(x)") plt.plot(x, df(x), label="f'(x)", linestyle="--") plt.axhline(0, color="gray", linewidth=0.5) plt.legend() plt.title("函数及其 derivative") plt.show()