Skip to content

Matrix Operations(矩阵运算)

Matrix 运算是深度学习的计算引擎。本节涵盖 matrix 加法、scalar 乘法、matrix-vector 乘积、matrix 乘法、逐元素运算、Kronecker 积和广播——这些运算支撑着每一次前向传播和梯度更新。

  • Matrix 可以像 vector 一样进行加法和缩放。

  • 加法时,两个 matrix 必须具有相同的维度,逐元素相加:

\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 6 & 8 \\ 10 & 12 \end{bmatrix} \]
  • Scalar 乘法时,将每个元素乘以该 scalar:
\[ 3 \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 3 & 6 \\ 9 & 12 \end{bmatrix} \]
  • 对 matrix 最基本的操作是将其与一个 vector 相乘。Matrix-vector 乘法 \(A\mathbf{x}\)\(\mathbf{x}\) 的元素作为权重,对 \(A\) 的各列进行线性组合:
\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 5 \\ 6 \end{bmatrix} = 5 \begin{bmatrix} 1 \\ 3 \end{bmatrix} + 6 \begin{bmatrix} 2 \\ 4 \end{bmatrix} = \begin{bmatrix} 17 \\ 39 \end{bmatrix} \]
  • 这是 ML 中的核心运算。每个神经网络 layer 都计算 \(A\mathbf{x} + \mathbf{b}\):matrix 乘以输入 vector,再加上 bias。

  • 更一般的情况是 matrix 乘法。给定 \(A\)\(m \times n\))和 \(B\)\(n \times p\)),乘积 \(C = AB\) 是一个 \(m \times p\) 的 matrix,其中每个元素是一个 dot product:

\[C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}\]
  • 结果中的每个元素是 \(A\) 的某行与 \(B\) 的某列的 dot product。内维度必须匹配(\(n\)),结果取外维度(\(m \times p\))。

  • 另一种理解方式:结果的每一列是 \(A\) 各列的加权和,权重来自 \(B\) 对应列。

  • \(B\) 的某列为 \([2, 3]^T\),则结果对应列为 \(2 \times (A \text{ 的第 1 列}) + 3 \times (A \text{ 的第 2 列})\)

  • 一个有用的特例:matrix 乘以其转置总给出方阵。\(AA^T\)\(m \times m\)\(A^TA\)\(n \times n\)

\[ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} = \begin{bmatrix} 14 & 32 \\ 32 & 77 \end{bmatrix} \]
  • Matrix 乘法有重要的规则:

    • 不满足交换律:一般情况下 \(AB \neq BA\)。顺序很重要。

    • 满足结合律\((AB)C = A(BC)\)。可以按任意方式分组。

    • 满足分配律\(A(B + C) = AB + AC\)

    • 单位元\(AI = IA = A\)

  • Hadamard 积(逐元素积)对相同大小的两个 matrix 逐元素相乘,写作 \(A \odot B\)

\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \odot \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 5 & 12 \\ 21 & 32 \end{bmatrix} \]
  • 与标准 matrix 乘法不同,Hadamard 积满足交换律(\(A \odot B = B \odot A\))且要求两个 matrix 具有相同的维度。它在 ML 中大量用于门控机制:用介于 0 和 1 之间的掩码逐元素相乘,控制每个元素"通过"的比例。

  • 两个 vector \(\mathbf{u}\)\(\mathbf{v}\)outer product 生成一个 matrix:\(\mathbf{u}\mathbf{v}^T\)。每个元素是 \(\mathbf{u}\) 的一个元素与 \(\mathbf{v}\) 的一个元素的乘积:

\[ \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 4 & 5 \end{bmatrix} = \begin{bmatrix} 4 & 5 \\ 8 & 10 \\ 12 & 15 \end{bmatrix} \]
  • 结果始终是 rank 为 1 的 matrix,因为每一行都是 \(\mathbf{v}^T\) 的缩放版本。任何 matrix 都可以写成 rank-1 outer product 的和——这正是 SVD 所做的(详见分解章节)。

  • Matrix 乘法在计算上代价高昂。两个 \(n \times n\) matrix 相乘需要 \(O(n^3)\) 次运算。对于 \(1000 \times 1000\) 的 matrix,这意味着十亿次乘法。

  • 当 matrix 稀疏(大多数元素为零)时,朴素乘法会浪费时间去乘以零。压缩稀疏行(CSR)格式只存储非零元素及其位置:

    • 数值(Values):按行顺序排列的非零元素
    • 列索引(Column indices):每个值所在的列
    • 行偏移(Row offsets):每行在数值列表中的起始位置
  • 例如,matrix:

\[ A = \begin{bmatrix} 5 & 0 & 0 & 2 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & -1 \end{bmatrix} \]
  • 存储为:数值 = [5, 2, 3, -1],列索引 = [0, 3, 2, 3],行偏移 = [0, 2, 3, 4]。这样跳过了所有零,使稀疏运算快得多。

  • Matrix 的核心用途之一是求解线性方程组。方程组 \(A\mathbf{x} = \mathbf{b}\) 问的是:"什么 vector \(\mathbf{x}\) 在被 \(A\) 变换后能产生 \(\mathbf{b}\)?"

  • 例如,假设你在购买水果。每个苹果价格 \(x_1\) 元,每根香蕉价格 \(x_2\) 元。已知 2 个苹果加 1 根香蕉花费 5 元,1 个苹果加 3 根香蕉花费 10 元。用 matrix 形式:

\[ \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} \]
  • 将 matrix 逐行乘以 vector(每行与 \([x_1, x_2]^T\) 做 dot product)得到两个方程:
\[2x_1 + 1x_2 = 5 \qquad \text{(第 1 行)} \qquad \qquad x_1 + 3x_2 = 10 \qquad \text{(第 2 行)}\]
  • 由第 1 行,\(x_2 = 5 - 2x_1\)。代入第 2 行:\(x_1 + 3(5 - 2x_1) = 10\),得 \(x_1 = 1\),则 \(x_2 = 3\)。苹果 1 元,香蕉 3 元。

  • 验证——结果正确:

\[ \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} 1 \\ 3 \end{bmatrix} = \begin{bmatrix} 2 + 3 \\ 1 + 9 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} \]
  • \(A\) 有逆,解就是 \(\mathbf{x} = A^{-1}\mathbf{b}\)。但直接计算逆代价高且数值不稳定。实际中我们使用分解方法。

  • 并非每个 matrix 都是方阵,也并非每个方阵都可逆。伪逆(pseudo-inverse) \(A^+\) 将逆推广到任意 matrix。它总是存在的,并提供"最优的近似逆":

\[A^+ = (A^TA)^{-1}A^T\]
  • \(A\) 是下三角 matrix 时,通过前代(forward substitution)可以轻松求解 \(L\mathbf{x} = \mathbf{b}\):先求 \(x_1\),然后用它求 \(x_2\),依此向下。

  • \(A\) 是上三角 matrix 时,通过回代(back substitution)求解 \(U\mathbf{x} = \mathbf{b}\):先求最后一个变量,然后向上逐一求解。

  • 这就是为什么将 matrix 分解为三角因子(如分解章节所介绍的)如此有用。它将一个困难问题转化为两个简单问题。

编程练习(使用 CoLab 或 notebook)

  1. 将两个 matrix 相乘并验证其维度。然后交换顺序,观察结果改变(或因维度不匹配而失败)。

    import jax.numpy as jnp
    
    A = jnp.array([[1.0, 2.0],
                   [3.0, 4.0]])
    B = jnp.array([[5.0, 6.0],
                   [7.0, 8.0]])
    
    print(f"A @ B:\n{A @ B}")
    print(f"B @ A:\n{B @ A}")
    print(f"相等: {jnp.allclose(A @ B, B @ A)}")
    

  2. 求解线性方程组 \(A\mathbf{x} = \mathbf{b}\),通过回代乘法验证解。尝试改变 \(\mathbf{b}\),观察解如何变化。

    import jax.numpy as jnp
    
    A = jnp.array([[2.0, 1.0],
                   [5.0, 3.0]])
    b = jnp.array([4.0, 7.0])
    
    x = jnp.linalg.solve(A, b)
    print(f"解 x: {x}")
    print(f"A @ x: {A @ x}")