pytorch中的cross_entropy函数

本文详细解析了PyTorch中的交叉熵函数cross_entropy的计算过程,包括输入参数input(模型预测概率)和target(真实类别标签)的处理。通过softmax将模型输出转换为概率分布,并对目标标签进行one-hot编码,简化计算公式。最终通过Python代码展示了从原始输出到交叉熵损失的计算步骤,并比较了函数和类的计算结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        cross_entropy函数是pytorch中计算交叉熵的函数。根据源码分析,输入主要包括两部分,一个是input,是维度为(batch_size,class)的矩阵,class表示分类的数量,这个就表示模型输出的预测结果;另一个是target,是维度为(batch_size)的一维向量,表示每个样本的真实值。输出是交叉熵的值。nn中的CrossEntropyLoss类与此函数的作用相同。


计算过程

        交叉熵是常见的损失函数,之前的文章中已经详细介绍了交叉熵的公式由来(交叉熵详解),公式如下:

L=-[y*log\hat{y}+(1-y)*log(1-\hat{y})]

        如果用在多分类问题中当做损失函数的话,一般会这样写:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}

        其中y是真实分类,是一个标签值;\hat{y}是模型预测结果,包含了属于每种标签的概率(此时这几个概率相加还不等于1)。在上面说了函数的输入分别是input和target,那么y就对应target这个向量,\hat{y}就对应input这个矩阵。但是现在就出现了两个问题:

         1、target就是一个标签值,无法与input直接进行运算,那么我们就对这个target先进行one-hot编码,使二者的维度相同。 比如target的值是[3],共有5个类,那么转换为one-hot编码之后就是:[0,0,0,1,0]。

        2、input中的预测值不能直接代表概率,而且这几个值相加不为1,这时就进行一个softmax操作,让模型的输出值满足上面两个条件,对输出结果的归一化公式如下:

\hat{y}=P(\hat{y}=i|x)=\frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}

        将\hat{y}带入到上面的损失函数公式中进行推导:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}=-\sum_{i=1}^{n}y\cdot \frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}=-\sum_{i=1}^{n}y(input[i]-log_{2}\sum_{j=1}^{n}e^{input_{[j]}})

        在第一点中我们已经将y转换为了[0,0,0,1,0]这样的编码,可见式子中只有target那一项的损失值需要计算,其他与0相乘就都消掉了,所以式子中最外层那个连加号就可以去掉了。最后式子就简化为:

L=-input[target]+log_{2}\sum_{j=1}^{n}e^{input_{[j]}}


Python实现

        现在已经清楚了cross_entropy这个函数的运算过程,现在用Python来模拟一下,首先构造一个随机的input和target:

output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
print("output:")
print(output)
print("label:")
print(label)

        输出结果如下:

        然后模拟函数的运算过程:

first = -input[0][target[0]]
second = 0
res=0
for j in range(5):
    second += math.exp(input[0][j])
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

         输出结果为:

         然后分别调用cross_entropy函数和CrossEntropyLoss类计算loss的结果:

criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)
loss2 = torch.nn.functional.cross_entropy(input=input,target=target)
print("cross_entropy函数计算loss的结果:")
print(loss)
print("CrossEntropyLoss类计算loss的结果:")
print(loss2)

        输出结果为:

<think>嗯,用户问的是cross_entropy函数有没有sigmoid。首先我需要确定用户指的是哪个cross_entropy函数,因为在PyTorch和TensorFlow中可能有不同的实现。假设用户指的是PyTorch中的nn.CrossEntropyLoss,因为这个问题比较常见。 PyTorchCrossEntropyLoss通常用于多分类问题,它结合了LogSoftmax和NLLLoss。也就是说,这个函数内部已经包含了softmax处理,所以在使用的时候不需要在模型的输出层再加softmax或者sigmoid。那sigmoid通常是用在二分类问题中的,这时候应该用BCEWithLogitsLoss,它内部包含了sigmoid和二元交叉熵的计算。所以用户可能混淆了这两个损失函数的应用场景。 接下来要解释交叉熵函数是否包含sigmoid激活。比如,CrossEntropyLoss不包含sigmoid,而是用softmax处理多分类。而BCEWithLogitsLoss包含sigmoid,处理二分类或多标签问题。所以需要明确用户的问题场景,如果是多分类,不需要sigmoid,如果是二分类或多标签,应该用BCEWithLogitsLoss,这时候内部有sigmoid。 另外,用户可能在使用CrossEntropyLoss时错误地添加了sigmoid,导致数值不稳定或者错误的结果。这时候需要指出正确的做法,即根据问题类型选择合适的损失函数,并避免重复应用激活函数。 最后,需要给出示例代码说明正确用法,比如在分类任务中使用CrossEntropyLoss时,模型的最后一层不需要激活函数,直接输出logits。而使用BCEWithLogitsLoss时同样不需要额外sigmoid。这样可以帮助用户理解如何正确应用这些函数。 总结起来,回答需要分情况讨论:CrossEntropyLoss不含sigmoid,适用于多分类;BCEWithLogitsLoss含sigmoid,适用于二分类或多标签。同时给出代码示例和常见错误,帮助用户正确使用。</think><< `cross_entropy`函数是否包含`sigmoid`取决于具体实现和框架的设计。以下是两种常见情况的分析: ### 1. **PyTorch的`nn.CrossEntropyLoss`** - **不包含Sigmoid**:PyTorch的`nn.CrossEntropyLoss`是为多分类任务设计的,其内部隐含了对输入的`LogSoftmax`操作,因此**不需要**在模型输出层显式添加`sigmoid`或`softmax`。 - 如果模型输出的是原始分数(logits),直接使用该损失函数即可。 - 示例代码: ```python import torch import torch.nn as nn # 模型输出原始分数(logits),未经过sigmoid或softmax logits = torch.randn(3, 5) # 3个样本,5个类别 labels = torch.tensor([1, 0, 4]) # 真实标签 criterion = nn.CrossEntropyLoss() loss = criterion(logits, labels) # 正确用法 ``` ### 2. **二分类任务中的交叉熵** - **需要Sigmoid**:对于二分类任务(如输出概率为0或1),通常使用`nn.BCEWithLogitsLoss`(二元交叉熵损失),该函数内部包含`sigmoid`操作。此时模型输出应为原始分数(logits)。 - 如果手动实现交叉熵损失,需显式调用`sigmoid`: ```python # 手动实现二分类交叉熵 outputs = torch.randn(3) # 3个样本的logits labels = torch.tensor([0, 1, 1], dtype=torch.float) sigmoid_outputs = torch.sigmoid(outputs) loss = -torch.mean(labels * torch.log(sigmoid_outputs) + (1 - labels) * torch.log(1 - sigmoid_outputs)) ``` ### 常见错误 - **错误用法**:对`nn.CrossEntropyLoss`的输入显式调用`softmax`或`sigmoid`,会导致数值不稳定或结果错误。 - **正确实践**:根据任务类型选择损失函数: - 多分类 → `CrossEntropyLoss`(无需`sigmoid`)。 - 二分类或多标签分类 → `BCEWithLogitsLoss`(隐含`sigmoid`)。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值