
本文针对pytorch模型训练中准确率异常低的问题进行深入探讨。核心原因在于模型评估阶段对正确预测数目的累加逻辑存在错误,以及对模型输入张量进行了不当展平。文章将详细解析这些常见陷阱,提供正确的代码修正方案,确保模型性能评估的准确性,帮助开发者有效诊断并解决训练过程中的此类问题。
在PyTorch模型训练过程中,开发者有时会遇到模型准确率始终处于极低水平,甚至低于随机猜测的情况,即使调整了批量大小、网络层数、迭代次数和学习率等超参数也无济于事。这种现象往往令人困惑,因为模型结构和数据加载看似正常。实际上,这通常不是模型本身无法学习,而是模型评估逻辑或数据预处理阶段存在缺陷,导致模型性能被错误地衡量。
通过对提供的代码进行分析,导致模型准确率异常的核心问题之一在于测试阶段对正确预测样本数 n_correct 的累加方式不正确。
在模型测试循环中,计算每个批次的正确预测数目的原始代码如下:
# ... (在测试循环内部) n_correct = (predictions == labels).sum().item()
这行代码的问题在于,它在每次迭代时都会重新赋值 n_correct,而不是将其与之前批次的正确预测数累加。这意味着 n_correct 最终只会保存最后一个批次的正确预测数,而不是整个测试集上的总和。因此,最终计算出的准确率将是基于单个批次而非整个数据集的,从而导致结果极低且不准确。
修正方法: 要解决此问题,只需将 n_correct 的赋值操作改为累加操作,确保在循环外部初始化 n_correct,并在循环内部使用 += 进行累加:
# ... (在测试循环内部) n_correct += (predictions == labels).sum().item()
除了 n_correct 的累加错误外,代码中还存在一个更基础但同样关键的潜在问题,即对模型输入 inputs 的不当 torch.flatten 操作。
在训练循环和测试循环中,都出现了以下代码:
# ... (在训练或测试循环内部) inputs = torch.flatten(inputs)
假设 input_size 被定义为5,且 DataLoader 提供的 inputs 形状为 (batch_size, 5)。对于 nn.Linear(input_size, hidden_size) 这样的全连接层,它期望的输入是 (batch_size, input_size)。如果对 inputs 进行 torch.flatten 操作,其形状将变为 (batch_size * 5)。这将导致 nn.Linear 层接收到的特征维度与 input_size 不匹配,或者在PyTorch内部进行不正确的自动调整,使模型接收到一个被展平的、不再具有原始特征结构的张量,从而无法进行有效的学习。
修正方法: 应移除训练和测试循环中对 inputs 的 torch.flatten(inputs) 操作。模型期望的输入形状通常由 nn.Linear 层的 in_features 参数决定,即 (batch_size, input_size)。
# 移除此行: # inputs = torch.flatten(inputs)
关于 labels 的形状处理: 代码中对 labels 也进行了 torch.flatten(labels) 操作。 由于 SDSS 数据集类中 self.y_data 的定义是 xy[:, [0]],这会使得 labels 的原始形状为 (batch_size, 1)。 nn.CrossEntropyLoss 期望的 target 形状是 (batch_size)(对于多分类问题,直接是类别索引)。因此,将 (batch_size, 1) 展平为 (batch_size) 是正确的处理方式。更语义化的做法是使用 labels = labels.squeeze(1),它明确表示移除维度为1的单维度。
以下是模型训练和测试循环中经过修正的关键部分代码:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
# device config
device = torch.device('cpu') # 示例使用CPU,若有GPU可改为'cuda'
input_size = 5
hidden_size = 10
num_classes = 3
num_epochs = 100
batch_size = 10
learning_rate = 0.0001
class SDSS(Dataset):
def __init__(self):
xy = np.loadtxt('SDSS.csv', delimiter=',', dtype=np.float32, skiprows=0)
self.n_samples = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, 1:]) # size [n_samples, n_features]
self.y_data = torch.from_numpy(xy[:, [以上就是PyTorch模型训练准确率异常:常见评估逻辑错误与修正方法的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号