Skip to content

Differential Calculus(微分学)

微分学捕捉瞬时变化率。本节涵盖极限、derivative、微分规则、链式法则(反向传播的基础),以及 ML 中常用的 derivative——这些是整个 ML 的核心工具。

  • 在前几章中,我们学习了如何将数据表示为 vector 并用 matrix 进行变换。但许多现实世界的现象并非静态的。汽车加速、股价波动、神经网络的 loss 随权重更新而变化。Calculus(微积分)是研究变化的数学。

  • Calculus 回答两个问题:某事物现在变化有多快?(微分学)以及它随时间积累了多少?(积分学)本节讨论"变化有多快"这个问题。

  • 想象你在开车,瞥了一眼速度表,显示 60 km/h。这个数字不是你整个旅程的平均速度,而是你此刻的速度。微分学给了我们计算这种瞬时变化率的工具。

  • 但首先,让我们回顾一下直线方程:\(y = mx + b\)

  • 这是两个量之间最简单的关系。

    • \(b\)y 截距,即直线与 y 轴的交点(当 \(x = 0\) 时的起始值)。
    • \(m\)斜率,即变化率:\(x\) 每增加 1 个单位,\(y\) 就变化 \(m\)
    • \(m = 3\),直线急剧上升;若 \(m = 0\),直线是水平的;若 \(m = -2\),直线下降。
  • 斜率计算为 \(m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}\),即"\(y\) 变化了多少"与"\(x\) 变化了多少"之比。

直线方程:b 是 y 截距,m 是斜率(上升量除以前进量)

  • 一旦知道 \(m\)\(b\),就可以对任意 \(x\) 计算 \(y\)

  • 例如,若 \(m = 2\)\(b = 3\),则当 \(x = 5\) 时:\(y = 2(5) + 3 = 13\)

  • 这两个参数完全确定了直线,预测任意输出只需代入即可。

  • 对于直线,斜率处处相同。

  • 这个思想可以推广到直线以外。任何函数都是将输入映射到输出的规则,一旦知道其公式(参数和形状),就能计算任意输入的输出并绘制结果。

  • \(y = x^2\) 给出抛物线,\(y = \sin(x)\) 给出波形,\(y = e^x\) 给出指数增长。每个公式定义一条特定的曲线,能够通过形状来读懂函数是后续一切内容的基础。

  • 对于直线,斜率处处相同。但大多数有趣的函数是曲线,斜率因点而异。Calculus 给了我们一种方法,能在曲线上的任意单点处求斜率。

  • 我们还需要极限的概念。极限描述的是:当函数的输入越来越接近某个目标时,函数值趋向于什么,而不必真的达到该目标。

\[\lim_{x \to a} f(x) = L\]
  • 这读作:"当 \(x\) 趋近 \(a\) 时,\(f(x)\) 趋近 \(L\)。"函数不需要在 \(x = a\) 处等于 \(L\),只需任意接近即可。

  • 例如,取 \(f(x) = \frac{x^2 - 1}{x - 1}\)。直接代入 \(x = 1\) 会得到 \(\frac{0}{0}\),这是未定义的。

  • 但试试接近 1 的值:\(f(0.9) = 1.9\)\(f(0.99) = 1.99\)\(f(1.01) = 2.01\)。输出显然趋向 2。

  • 从代数上可以看出原因:将分子分解为 \((x-1)(x+1)\),约去 \((x-1)\),对所有 \(x \neq 1\)\(f(x) = x + 1\)。因此当 \(x \to 1\) 时,\(f(x) \to 2\)

  • 该函数在 \(x = 1\) 处有一个"洞",但极限仍然存在。

  • 极限是 calculus 中一切其他内容的基础。

  • 函数 \(f(x)\) 在点 \(x = a\) 处的 derivative 度量的是瞬时变化率。从几何上看,它是曲线在该点处切线的斜率。

Derivative 是曲线上某点处切线的斜率

  • 为了计算这个斜率,我们从曲线上取两个点,计算过这两点的直线(割线)的斜率,然后让第二个点越来越靠近第一个点,观察割线斜率趋近于什么值。这就是差商
\[f'(a) = \lim_{h \to 0} \frac{f(a + h) - f(a)}{h}\]

随着 h 缩小,割线趋近于切线

  • 分子 \(f(a+h) - f(a)\) 是输出的变化量。分母 \(h\) 是输入的变化量。它们的比值是在微小区间上的平均变化率。当 \(h \to 0\) 时,这个平均值变为瞬时变化率。

  • 例如,令 \(f(x) = x^2\),在 \(x = 3\) 处:

\[f'(3) = \lim_{h \to 0} \frac{(3+h)^2 - 9}{h} = \lim_{h \to 0} \frac{9 + 6h + h^2 - 9}{h} = \lim_{h \to 0} (6 + h) = 6\]
  • 所以在 \(x = 3\) 处,函数 \(x^2\) 的变化率是每单位输入变化 6 个单位输出。

  • 若函数在某点处该极限存在,则称其在该点可微。为此,函数必须连续(无跳变)、光滑(无尖角),且在该点邻域内有定义。

  • 如果你能不抬笔、不转折地画出该曲线,它在那里很可能是可微的。

  • 每次都从极限定义来计算 derivative 会很繁琐。幸运的是,几条规则让我们能快速对几乎任何函数求导。

  • 常数规则:常数的 derivative 为零。若 \(f(x) = 5\),则 \(f'(x) = 0\)。水平线的斜率为零。

  • 幂次规则:微分的主力。将指数移到前面,指数减一:

\[\frac{d}{dx} x^n = n x^{n-1}\]
  • 例如:\(\frac{d}{dx} x^3 = 3x^2\)。三次变为二次。这对任意实数指数均适用,包括负数和分数:\(\frac{d}{dx} x^{-1} = -x^{-2}\)\(\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}\)

  • 和差规则:逐项求导。

\[\frac{d}{dx}[f(x) \pm g(x)] = f'(x) \pm g'(x)\]
  • 乘积规则:两个函数相乘时,derivative 并非简单地是各 derivative 之积。而是:
\[\frac{d}{dx}[f(x) \cdot g(x)] = f'(x)g(x) + f(x)g'(x)\]
  • 可以理解为:"第一个的变化率乘以第二个,加上第一个乘以第二个的变化率。"例如,\(\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x\)

  • 商法则:对函数之比:

\[\frac{d}{dx}\left[\frac{f(x)}{g(x)}\right] = \frac{f'(x)g(x) - f(x)g'(x)}{[g(x)]^2}\]
  • 一个有用的记忆口诀:"分母乘分子的导数,减分子乘分母的导数,除以分母的平方。"

  • 链式法则:ML 中最重要的规则。当函数复合(一个套在另一个里面)时,derivative 是沿链各 derivative 的乘积:

\[\frac{d}{dx} f(g(x)) = f'(g(x)) \cdot g'(x)\]
  • 把它想象成剥洋葱。先对外层函数求导(保持内层函数不变),再乘以内层函数的 derivative。

链式法则:对外层求导,再乘以内层的 derivative

  • 例如,\(\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4\)。外层函数是 \((\cdot)^5\),内层函数是 \(3x+1\)

  • 链式法则是神经网络反向传播的数学基础。深层网络是一长串复合函数。为了计算 loss 关于每个权重的变化,我们从输出 layer 到输入逐层反向应用链式法则,在每一步将局部 derivative 相乘。

  • 以下是你会遇到的最常见的 derivative。每一个都可以从极限定义推导出来,但熟记它们能节省时间:

函数 Derivative 备注
\(e^x\) \(e^x\) 唯一等于自身 derivative 的函数
\(a^x\) \(a^x \ln a\) 推广了指数
\(\ln x\) \(\frac{1}{x}\) 自然对数
\(\log_a x\) \(\frac{1}{x \ln a}\) 一般对数
\(\sin x\) \(\cos x\)
\(\cos x\) \(-\sin x\) 注意负号
\(\tan x\) \(\sec^2 x\)
  • 指数函数 \(e^x\) 很特别:它是唯一等于自身 derivative 的函数。这就是为什么 \(e\) 在 ML 中无处不在,从 softmax activation 到概率分布。

  • 洛必达法则(L'Hopital's Rule)处理产生 \(\frac{0}{0}\)\(\frac{\infty}{\infty}\) 等不定形式的极限。当直接代入给出这些不定形式时,可以分别对分子和分母求导,然后再求极限:

\[\lim_{x \to a} \frac{f(x)}{g(x)} = \lim_{x \to a} \frac{f'(x)}{g'(x)}\]
  • 条件:\(f\)\(g\)\(a\) 附近必须可微,且 \(g'(x) \neq 0\)\(a\) 点除外)。原极限必须产生不定形式。

  • 例如:\(\lim_{x \to 0} \frac{\sin x}{x}\)。直接代入给出 \(\frac{0}{0}\)。应用洛必达法则:\(\lim_{x \to 0} \frac{\cos x}{1} = 1\)。这个极限很基本——它出现在信号处理和 Fourier 分析中。

  • 若结果仍为不定形式,可以反复应用该法则。例如,\(\lim_{x \to 0} \frac{1 - \cos x}{x^2}\) 给出 \(\frac{0}{0}\)。第一次应用:\(\lim_{x \to 0} \frac{\sin x}{2x}\),仍为 \(\frac{0}{0}\)。第二次应用:\(\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}\)

  • 如果两个函数可微,它们的和、差、积、复合以及商(分母非零时)也都是可微的。这就是为什么我们能有把握地对由简单部分构成的复杂表达式求导。

编程练习(使用 CoLab 或 notebook)

  1. 可视化常见函数。并排绘制 \(x^2\)\(\sin(x)\)\(e^x\),建立对不同公式如何产生不同形状的直觉。尝试修改参数(如 \(2x^2\)\(\sin(2x)\)),观察曲线如何变化。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    x = jnp.linspace(-3, 3, 300)
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 3))
    axes[0].plot(x, x**2, color="#e74c3c")
    axes[0].set_title("x²(抛物线)")
    axes[1].plot(x, jnp.sin(x), color="#3498db")
    axes[1].set_title("sin(x)(波形)")
    axes[2].plot(x, jnp.exp(x), color="#27ae60")
    axes[2].set_title("eˣ(指数)")
    for ax in axes:
        ax.axhline(0, color="gray", linewidth=0.5)
        ax.axvline(0, color="gray", linewidth=0.5)
    plt.tight_layout()
    plt.show()
    

  2. 使用 JAX 的自动微分计算 \(f(x) = x^3 - 2x + 1\) 在多个点处的 derivative。与解析 derivative \(f'(x) = 3x^2 - 2\) 进行比较。

    import jax
    import jax.numpy as jnp
    
    f = lambda x: x**3 - 2*x + 1
    df = jax.grad(f)
    
    for x in [0.0, 1.0, 2.0, -1.0]:
        print(f"x={x:5.1f}  自动微分: {df(x):.4f}  解析值: {3*x**2 - 2:.4f}")
    

  3. 数值验证链式法则。定义 \(f(x) = \sin(x^2)\),通过 jax.grad 计算其 derivative,并与解析结果 \(2x\cos(x^2)\) 比较。

    import jax
    import jax.numpy as jnp
    
    f = lambda x: jnp.sin(x**2)
    df = jax.grad(f)
    
    for x in [0.5, 1.0, 2.0]:
        auto = df(x)
        analytical = 2*x * jnp.cos(x**2)
        print(f"x={x:.1f}  自动微分: {auto:.6f}  解析值: {analytical:.6f}")
    

  4. 可视化 derivative。在同一图上绘制 \(f(x) = x^3 - 3x\) 及其 derivative \(f'(x) = 3x^2 - 3\)。注意 \(f'(x) = 0\) 的位置对应 \(f\) 的极大值和极小值。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    f = lambda x: x**3 - 3*x
    # jax.grad 作用于 scalar;jax.vmap 将其向量化,以便对输入数组逐元素运算
    df = jax.vmap(jax.grad(f))
    
    x = jnp.linspace(-2.5, 2.5, 200)
    plt.plot(x, jax.vmap(f)(x), label="f(x)")
    plt.plot(x, df(x), label="f'(x)", linestyle="--")
    plt.axhline(0, color="gray", linewidth=0.5)
    plt.legend()
    plt.title("函数及其 derivative")
    plt.show()