面试官总问的交叉熵:从信息论到PyTorch实战,一次讲清分类任务为什么用它

张开发
2026/4/4 1:34:39 15 分钟阅读
面试官总问的交叉熵:从信息论到PyTorch实战,一次讲清分类任务为什么用它
交叉熵从信息论到PyTorch实战揭秘分类任务的核心损失函数在机器学习面试中当面试官问到为什么分类问题用交叉熵而不用均方误差(MSE)时大多数候选人会给出一个标准答案因为交叉熵在分类问题上效果更好。但这样的回答显然无法让面试官满意。真正理解交叉熵背后的原理并能在实际项目中做出明智选择是区分优秀工程师和普通工程师的关键。1. 信息论基础理解交叉熵的本质交叉熵的概念源于信息论要真正理解它我们需要从几个基本概念开始。1.1 信息量与熵信息量衡量的是一个事件发生所带来的惊讶程度。一个几乎确定会发生的事件概率接近1提供的信息量很少而一个不太可能发生的事件概率接近0则提供大量信息。数学上信息量定义为I(x) -log(P(x))其中P(x)是事件x发生的概率。这个对数通常以2为底单位是比特或以自然对数e为底单位是纳特。熵信息熵则是所有可能事件的信息量的期望值H(X) -Σ P(x)log(P(x))熵可以理解为系统的不确定性或混乱程度。一个公平的硬币抛掷50%正反面比一个有偏见的硬币90%正面具有更高的熵。1.2 KL散度与交叉熵KL散度Kullback-Leibler divergence衡量两个概率分布P和Q之间的差异D_KL(P||Q) Σ P(x)log(P(x)/Q(x))交叉熵则是KL散度与信息熵的组合H(P,Q) -Σ P(x)log(Q(x)) H(P) D_KL(P||Q)在机器学习中P是真实分布标签Q是模型预测的分布。由于H(P)是固定的最小化交叉熵等价于最小化KL散度也就是让预测分布Q尽可能接近真实分布P。提示在实际分类任务中我们通常使用one-hot编码表示类别标签此时真实分布P的熵H(P)0交叉熵直接等于KL散度。2. 为什么分类问题偏爱交叉熵与MSE的深入对比2.1 梯度特性分析交叉熵在分类任务中表现优异的关键在于它的梯度特性。让我们比较一下交叉熵和均方误差(MSE)在二分类问题中的梯度交叉熵损失梯度∂L/∂w (σ(wx) - y)xMSE损失梯度∂L/∂w (σ(wx) - y)σ(wx)(1-σ(wx))x可以看到MSE的梯度多了一个σ(1-σ)项。当预测值接近0或1时即分类比较确定时这个项会变得非常小导致梯度消失学习速度变慢。而交叉熵的梯度则直接正比于误差(σ-y)没有这个衰减因子。2.2 收敛速度实验让我们用PyTorch在MNIST数据集上实际比较两种损失函数的收敛速度import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms # 数据准备 transform transforms.Compose([transforms.ToTensor()]) train_set datasets.MNIST(root./data, trainTrue, downloadTrue, transformtransform) train_loader torch.utils.data.DataLoader(train_set, batch_size64, shuffleTrue) # 简单模型 model nn.Sequential( nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10) ) # 交叉熵实验 ce_model model.copy() ce_criterion nn.CrossEntropyLoss() ce_optimizer optim.SGD(ce_model.parameters(), lr0.1) # MSE实验 mse_model model.copy() mse_criterion nn.MSELoss() mse_optimizer optim.SGD(mse_model.parameters(), lr0.1) # 训练循环 for epoch in range(5): for images, labels in train_loader: # 交叉熵训练 ce_optimizer.zero_grad() outputs ce_model(images) loss ce_criterion(outputs, labels) loss.backward() ce_optimizer.step() # MSE训练需要one-hot编码 mse_optimizer.zero_grad() outputs mse_model(images) one_hot torch.zeros_like(outputs) one_hot.scatter_(1, labels.unsqueeze(1), 1) loss mse_criterion(outputs, one_hot) loss.backward() mse_optimizer.step()实验结果表明使用交叉熵损失的模型在前几轮就能达到较高准确率而MSE损失的模型收敛明显更慢。这是因为在分类边界附近MSE的梯度会变得非常小导致参数更新缓慢。3. 交叉熵的PyTorch实战应用3.1 多分类任务中的实现在PyTorch中nn.CrossEntropyLoss实际上结合了softmax和交叉熵计算因此模型的最后一层不需要额外添加softmax激活。下面是一个完整的多分类实现示例class Classifier(nn.Module): def __init__(self): super().__init__() self.net nn.Sequential( nn.Conv2d(1, 32, 3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten(), nn.Linear(64*7*7, 128), nn.ReLU(), nn.Linear(128, 10) ) def forward(self, x): return self.net(x) model Classifier() criterion nn.CrossEntropyLoss() optimizer optim.Adam(model.parameters()) # 训练循环 for epoch in range(10): for images, labels in train_loader: optimizer.zero_grad() outputs model(images) loss criterion(outputs, labels) loss.backward() optimizer.step()3.2 处理类别不平衡问题当数据集中各类别样本数量不均衡时简单的交叉熵可能导致模型偏向多数类。PyTorch提供了加权交叉熵来解决这个问题# 假设我们有一个各类别样本数的列表class_counts class_weights 1. / torch.tensor(class_counts, dtypetorch.float) # 归一化权重 class_weights class_weights / class_weights.sum() criterion nn.CrossEntropyLoss(weightclass_weights)另一种方法是使用标签平滑(label smoothing)它可以防止模型对预测结果过于自信criterion nn.CrossEntropyLoss(label_smoothing0.1)4. 高级话题交叉熵的变体与应用4.1 二分类与多分类的统一虽然我们通常分开讨论二分类使用sigmoid和多分类使用softmax问题但实际上它们可以统一看待。二分类交叉熵可以看作是多分类交叉熵在类别数为2时的特例# 二分类实现方式1 criterion nn.BCELoss() # 需要模型输出经过sigmoid # 二分类实现方式2更推荐 criterion nn.BCEWithLogitsLoss() # 内置sigmoid数值更稳定 # 多分类实现 criterion nn.CrossEntropyLoss() # 内置softmax4.2 交叉熵在自监督学习中的应用交叉熵不仅在监督学习中大放异彩在自监督学习中也扮演重要角色。例如在对比学习中InfoNCE损失本质上也是一种交叉熵# 简化的对比学习损失实现 def info_nce_loss(features, temperature0.1): batch_size features.size(0) # 计算相似度矩阵 sim_matrix torch.matmul(features, features.T) / temperature # 对角线是正样本对 labels torch.arange(batch_size).to(features.device) # 计算交叉熵损失 return nn.CrossEntropyLoss()(sim_matrix, labels)4.3 交叉熵与模型校准现代神经网络常常会过度自信即对预测结果赋予过高的概率。测量和改善模型校准度的一个常用指标是预期校准误差(ECE)它与交叉熵密切相关def expected_calibration_error(logits, labels, n_bins10): probabilities torch.softmax(logits, dim1) confidences, predictions torch.max(probabilities, 1) accuracies predictions.eq(labels) bin_boundaries torch.linspace(0, 1, n_bins 1) bin_lowers bin_boundaries[:-1] bin_uppers bin_boundaries[1:] ece 0 for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): in_bin confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) prop_in_bin in_bin.float().mean() if prop_in_bin 0: accuracy_in_bin accuracies[in_bin].float().mean() avg_confidence_in_bin confidences[in_bin].mean() ece torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin return ece在实际项目中我们常常会记录交叉熵损失和ECE两个指标交叉熵衡量模型的预测能力ECE则衡量预测概率的可靠性。

更多文章