Introduction to Torch & Einops
Torch
在具体记录一些 Torch 的 common-use functions and classes 之前,我有必要确立 PyTorch 的心智模型 — 当我们在谈论 Torch 的时候,我们在谈论什么。
Einops
Classical Functions Reinvention
LogSumExp
为什么需要计算 LogSumExp ?
实际应用
Softmax 是神经网络中最常见的操作:
两边都取 log:
LogSumExp 就是 Softmax 的分母部分,数值稳定的 Softmax 就依赖它。
Note (数值稳定是什么意思?)
假设我们现在要计算 ,直接计算的问题:
如果 , 会数值溢出(Inf)。
如果 , 会下溢为 0,导致 。
数值稳定的数学技巧 —— 减去最大值
其中 ,这是因为:
为什么这样就稳定了?
- ,所以 ,不会溢出
- 至少有一个 ,所以求和结果 , 不会是
Code
def batched_logsumexp(matrix: Tensor) -> Tensor: c = matrix.max(dim=-1, keepdim=True).values # Tensor.max 返回 torch.return_types.max,有两个字段 values & indices shifted = matrix - c # matrix shape (batch, n) # c shape (batch, 1) # broadcasting c --> (batch, n) return c.squeeze(-1) + shifted.exp().sum(dim=-1).log()Softmax & LogSoftmax
我现在来推导一下数值稳定版的 Softmax 和 LogSoftmax。
LogSoftmax 推导和代码实现
数学推导
首先,Softmax 的原表达式为:
数值稳定版本,我需要算出 LogSumExp,其中 :
给 Softmax 原表达式两边加上 log:
最后将右侧代换,就得到了数值稳定版的 LogSoftmax:
代码实现
def batched_logsoftmax(matrix: Tensor) -> Tensor: """Compute log(softmax(row)) for each row of the matrix.
matrix: shape (batch, n) Return: shape (batch, n) """ # 数值稳定:先减最大值 c = matrix.max(dim=-1, keepdim=True).values # (batch, 1) # 对应 LogSumExp 表达式 log_sum_exp = c + (matrix - c).exp().sum(dim=-1, keepdim=True).log() # keepdim=True 保持 shape 为 (batch, 1)
# 最终结果 LogSoftmax # broadcasting: log_sum_exp 从 (batch, 1) 广播为 (batch, n) return matrix - log_sum_expSoftmax 推导和代码实现
数学推导
从上面已经得出了数值稳定的 LogSoftmax 表示:
此时,加上 exp 可以得到 Softmax 的另一种表达方式:
代码实现
def batched_softmax(matrix: Tensor) -> Tensor: c = matrix.max(dim=-1, keepdim=True).values shifted_exp = (matrix - c).exp() return shifted_exp / shifted_exp.sum(dim=-1, keepdim=True)Cross Entropy Loss
数学推导
交叉熵可以分别从统计角度和信息论的角度来解析和推导,具体的内容可以 refer to ML 的 Foundation 部分(还在施工中)。
Cross Entropy Loss 本质上是在惩罚模型没有把足够大的概率分配给真实分布(类别)。
假设真实的概率分布为 (通常是 one-hot,只有一个类别概率是 1,其余的概率都是 0);
假设模型学习到的概率分布为 。
Information Theory 给的直觉 —— 概率越小的事件,发生时信息量越大,必然事件没有信息量。以此确定的自信息(Self-Information)和交叉熵分别如下所示:
代码实现
def batched_cross_entropy_loss(logits: Tensor, true_labels: Tensor) -> Tensor: """Compute the cross entropy loss for each example in the batch.
logits: shape (batch, classes). logits[i][j] is the unnormalized prediction for example i and class j. true_labels: shape (batch, ). true_labels[i] is an integer index representing the true class for example i.
Return: shape (batch, ). out[i] is the loss for example i. """ log_prob = batched_logsoftmax(logits) # return (batch, classes) return -log_prob[t.arange(len(true_labels)), true_labels]