PyTorch Loss函数深度解析
引言
损失函数(Loss Function)是深度学习中至关重要的组成部分,它衡量模型预测值与真实值之间的差异,指导模型参数的优化方向。PyTorch提供了丰富的损失函数库,本文将深入解析12种常用损失函数的数学原理、应用场景和代码实践。
1. nn.MSELoss - 均方误差损失
数学公式
均方误差损失(Mean Squared Error Loss)计算预测值与真实值之间差的平方的均值:
$$
\text{MSE}(x, y) = \frac{1}{n} \sum_{i=1}^{n} (x_i - y_i)^2
$$
在PyTorch中,还可以设置reduction参数:
'mean'(默认):返回损失的平均值'sum':返回损失的总和'none':返回每个元素的损失
用途场景
- 回归问题:预测连续值,如房价预测、股票价格预测
- 图像重建:自编码器、图像超分辨率任务
- 语音合成:生成模型的重建损失
- 信号处理:噪声消除、信号恢复
代码示例
1 | |
注意事项
- 对异常值敏感:由于平方操作,异常值会产生较大的梯度,可能导致训练不稳定
- 梯度消失问题:当预测值与真实值接近时,梯度会变得很小
- 数值稳定性:对于非常大的值,平方可能导致数值溢出
- 学习率选择:由于梯度与误差成正比,需要合理设置学习率
2. nn.CrossEntropyLoss - 交叉熵损失
数学公式
交叉熵损失(Cross-Entropy Loss)结合了LogSoftmax和NLLLoss:
$$
\text{CE}(x, y) = -\sum_{c=1}^{C} y_c \log(\text{softmax}(x)_c)
$$
对于单标签分类(真实标签为one-hot编码):
$$
\text{CE}(x, y) = -\log\left(\frac{e^{x_{y}}}{\sum_{j} e^{x_j}}\right)
$$
用途场景
- 多分类任务:图像分类(ImageNet、CIFAR)、文本分类
- 语言模型:预测下一个词的概率分布
- 命名实体识别:序列标注任务
- 目标检测:类别预测分支
代码示例
1 | |
注意事项
- 不要在输入前加Softmax:该损失函数内部已经包含了Softmax操作
- 目标标签格式:类别索引(非one-hot编码)
- 类别不平衡:使用
weight参数处理不平衡数据集 - 标签平滑:使用
label_smoothing参数防止过拟合
3. nn.BCELoss / nn.BCEWithLogitsLoss - 二分类交叉熵
数学公式
二分类交叉熵(Binary Cross-Entropy):
$$
\text{BCE}(x, y) = -[y \log(x) + (1-y) \log(1-x)]
$$
对于批量数据:
$$
\text{BCE}(x, y) = -\frac{1}{n}\sum_{i=1}^{n} [y_i \log(x_i) + (1-y_i) \log(1-x_i)]
$$
代码示例
1 | |
4. nn.NLLLoss - 负对数似然损失
数学公式
$$
\text{NLL}(x, y) = -x_{y}
$$
其中x是已经过LogSoftmax处理的输入。
代码示例
1 | |
5. nn.L1Loss - L1损失
数学公式
$$
\text{L1}(x, y) = \frac{1}{n} \sum_{i=1}^{n} |x_i - y_i|
$$
用途场景
- 回归任务:对异常值更鲁棒
- 图像处理:图像去噪、风格迁移
- 稀疏编码:L1正则化诱导稀疏解
6. nn.SmoothL1Loss - 平滑L1损失
数学公式
$$
\text{SmoothL1}(x, y) = \begin{cases}
\frac{1}{2}(x - y)^2 / \beta & \text{if } |x - y| < \beta \
|x - y| - \frac{1}{2}\beta & \text{otherwise}
\end{cases}
$$
其中β(beta)是阈值参数,默认为1.0。
7. nn.KLDivLoss - KL散度损失
数学公式
$$
\text{KL}(P || Q) = \sum_{i} P(i) \log\left(\frac{P(i)}{Q(i)}\right)
$$
用途场景
- 知识蒸馏:教师模型向学生模型传递知识
- 变分自编码器(VAE):正则化隐空间分布
- 强化学习:策略优化(PPO、SAC)
8. nn.MarginRankingLoss - 边际排序损失
数学公式
$$
\text{loss}(x_1, x_2, y) = \max(0, -y \cdot (x_1 - x_2) + \text{margin})
$$
其中y ∈ {-1, 1}表示排序关系。
9. nn.TripletMarginLoss - 三元组边际损失
数学公式
$$
\text{loss}(a, p, n) = \max(0, d(a, p) - d(a, n) + \text{margin})
$$
用于人脸识别、行人重识别、图像检索等。
10. nn.CosineEmbeddingLoss - 余弦嵌入损失
数学公式
$$
\text{loss}(x_1, x_2, y) = \begin{cases}
1 - \cos(x_1, x_2) & \text{if } y = 1 \
\max(0, \cos(x_1, x_2) - \text{margin}) & \text{if } y = -1
\end{cases}
$$
11. nn.CTCLoss - CTC损失
数学公式
$$
\text{CTC}(x, y) = -\log P(y|x) = -\log \sum_{\pi \in \mathcal{B}^{-1}(y)} P(\pi|x)
$$
用于语音识别、OCR等序列任务。
12. nn.HingeEmbeddingLoss - 铰链嵌入损失
数学公式
$$
\text{loss}(x, y) = \begin{cases}
x & \text{if } y = 1 \
\max(0, \text{margin} - x) & \text{if } y = -1
\end{cases}
$$
总结
损失函数选择指南
| 任务类型 | 推荐损失函数 | 备注 |
|---|---|---|
| 二分类 | BCEWithLogitsLoss | 数值稳定,包含Sigmoid |
| 多分类 | CrossEntropyLoss | 包含Softmax |
| 回归 | MSELoss / L1Loss | MSE对异常值敏感 |
| 边界框回归 | SmoothL1Loss | 目标检测标配 |
| 知识蒸馏 | KLDivLoss | 需要温度参数 |
| 度量学习 | TripletMarginLoss | 人脸识别、ReID |
| 语音识别 | CTCLoss | 无需对齐 |
| 排序学习 | MarginRankingLoss | 推荐系统 |
最佳实践
- 数值稳定性:优先使用包含内置激活函数的损失(如BCEWithLogitsLoss)
- 类别不平衡:使用class_weight或pos_weight参数
- 梯度分析:理解损失函数的梯度特性有助于选择合适的学习率
- 组合损失:复杂任务可以组合多种损失函数
1 | |
希望本文能帮助你深入理解PyTorch损失函数,在实际项目中做出正确的选择!


