
本文深入探讨了pytorch中`crossentropyloss`常见的`runtimeerror: expected scalar type long but found float`错误。该错误通常源于目标标签(target)的数据类型不符合`crossentropyloss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`crossentropyloss`,包括标签类型转换、输入顺序以及避免重复应用softmax等关键最佳实践,以确保模型训练的稳定性和准确性。
在深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。
CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:
关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。
这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:
loss = criterion(output, labels.float())
尽管labels张量在创建时已经被明确指定为long类型:
labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())
但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。
修正方法: 只需移除对labels的.float()调用,确保target张量保持其long类型即可。
# 错误代码 # loss = criterion(output, labels.float()) # 正确代码 loss = criterion(output, labels)
除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。
在train_one_epoch函数内部,标签被错误地转换成了float类型:
labels = labels.to(device).float() # 错误:将标签转换为float类型
这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。
修正方法: 确保标签在送入损失函数前是long类型。
labels = labels.to(device).long() # 正确:将标签转换为long类型
在train_one_epoch函数中,计算损失的行是:
loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符
这里存在两个问题:
修正方法: 将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。
在计算outputs时,代码中显式地应用了F.softmax:
outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax
由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:
修正方法: 直接将模型的原始输出(logits)传递给CrossEntropyLoss。
综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# 假设 model, optimizer, dataloaders, device 已经定义
def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    start_time = time.time()
    total = 0
    correct = 0
    # 确保 data_loader 是实际的 DataLoader 对象
    # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader
    current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典
    for i, (inputs, labels) in enumerate(current_data_loader):
        inputs = inputs.to(device)
        # 核心修正:确保标签是long类型
        labels = labels.to(device).long() 
        optimizer.zero_grad()
        # 修正:直接使用模型的原始输出(logits),不应用Softmax
        # 假设 model(inputs.float()) 返回的是 logits
        logits = model(inputs.float()) 
        # 打印形状以调试
        # print("Inputs shape:", inputs.shape)
        # print("Logits shape:", logits.shape)
        # print("Labels shape:", labels.shape)
        # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)
        loss = criterion(logits, labels) 
        loss.backward()
        optimizer.step()
        # 计算准确率时,需要对logits应用argmax
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 batches
            batch_time = time.time()
            speed = (i+1)/(batch_time-start_time)
            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
                  (i, running_loss, speed, accuracy))
            running_loss = 0.0
            total = 0
            correct = 0val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。
唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。
def val_model(model, data_loader, device): # 添加 device 参数
    model.eval() # 修正:使用 model.eval()
    start_time = time.time()
    total = 0
    correct = 0
    current_data_loader = data_loader
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader]
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(current_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).long() # 正确
            outputs = model(inputs.float()) # 假设 model 输出 logits
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可
            correct += (predicted == labels).sum().item() 
        accuracy = 100 * correct / total
        print('Finished Testing')
        print('Testing accuracy: %.1f %%' %(accuracy))处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:
遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。
以上就是PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践的详细内容,更多请关注php中文网其它相关文章!
 
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
 
                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号