PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

花韻仙語
发布: 2025-10-26 12:03:14
原创
664人浏览过

PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

本文深入探讨了pytorch中`crossentropyloss`常见的`runtimeerror: expected scalar type long but found float`错误。该错误通常源于目标标签(target)的数据类型不符合`crossentropyloss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`crossentropyloss`,包括标签类型转换、输入顺序以及避免重复应用softmax等关键最佳实践,以确保模型训练的稳定性和准确性。

深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。

理解CrossEntropyLoss的工作原理

CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:

  1. input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。
  2. target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。

关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。

RuntimeError: expected scalar type Long but found Float 错误解析与修正

这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:

loss = criterion(output, labels.float())
登录后复制

尽管labels张量在创建时已经被明确指定为long类型:

labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())
登录后复制

但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。

修正方法: 只需移除对labels的.float()调用,确保target张量保持其long类型即可。

# 错误代码
# loss = criterion(output, labels.float())

# 正确代码
loss = criterion(output, labels)
登录后复制

训练循环中的常见误用及修正

除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。

1. 标签数据类型转换错误

在train_one_epoch函数内部,标签被错误地转换成了float类型:

labels = labels.to(device).float() # 错误:将标签转换为float类型
登录后复制

这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。

修正方法: 确保标签在送入损失函数前是long类型。

labels = labels.to(device).long() # 正确:将标签转换为long类型
登录后复制

2. CrossEntropyLoss输入参数顺序和类型错误

在train_one_epoch函数中,计算损失的行是:

文心大模型
文心大模型

百度飞桨-文心大模型 ERNIE 3.0 文本理解与创作

文心大模型56
查看详情 文心大模型
loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符
登录后复制

这里存在两个问题:

  • 参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。
  • target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。

修正方法: 将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。

3. 预先应用Softmax的错误

在计算outputs时,代码中显式地应用了F.softmax:

outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax
登录后复制

由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:

  • 冗余计算: 增加了不必要的计算开销。
  • 数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。

修正方法: 直接将模型的原始输出(logits)传递给CrossEntropyLoss。

优化后的训练函数示例

综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

# 假设 model, optimizer, dataloaders, device 已经定义

def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    start_time = time.time()
    total = 0
    correct = 0

    # 确保 data_loader 是实际的 DataLoader 对象
    # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader
    current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典

    for i, (inputs, labels) in enumerate(current_data_loader):
        inputs = inputs.to(device)
        # 核心修正:确保标签是long类型
        labels = labels.to(device).long() 

        optimizer.zero_grad()

        # 修正:直接使用模型的原始输出(logits),不应用Softmax
        # 假设 model(inputs.float()) 返回的是 logits
        logits = model(inputs.float()) 

        # 打印形状以调试
        # print("Inputs shape:", inputs.shape)
        # print("Logits shape:", logits.shape)
        # print("Labels shape:", labels.shape)

        # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)
        loss = criterion(logits, labels) 

        loss.backward()
        optimizer.step()

        # 计算准确率时,需要对logits应用argmax
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total

        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 batches
            batch_time = time.time()
            speed = (i+1)/(batch_time-start_time)
            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
                  (i, running_loss, speed, accuracy))
            running_loss = 0.0
            total = 0
            correct = 0
登录后复制

验证模型函数 (val_model) 的注意事项

val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。

唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。

def val_model(model, data_loader, device): # 添加 device 参数
    model.eval() # 修正:使用 model.eval()
    start_time = time.time()
    total = 0
    correct = 0

    current_data_loader = data_loader
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader]

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(current_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).long() # 正确

            outputs = model(inputs.float()) # 假设 model 输出 logits

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可
            correct += (predicted == labels).sum().item() 
        accuracy = 100 * correct / total

        print('Finished Testing')
        print('Testing accuracy: %.1f %%' %(accuracy))
登录后复制

总结与最佳实践

处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:

  1. 目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。
  2. 模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。
  3. 避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。
  4. 参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。
  5. 评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。

遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。

以上就是PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号