
在pytorch中构建和训练cnn时,开发者经常会遇到各种形状(shape)或维度(dimension)不匹配的错误。这些错误通常发生在数据从卷积层过渡到全连接层时,或者在计算损失时。理解这些错误的根源并掌握正确的调试方法对于成功训练深度学习模型至关重要。
根据提供的代码和错误描述,主要存在以下几个维度不匹配问题:
全连接层输入维度计算错误: 卷积层和池化层处理图像后,特征图的尺寸会发生变化。在将特征图展平(flatten)并输入到全连接层(nn.Linear)时,全连接层期望的输入特征数量必须与展平后的实际特征数量完全匹配。原始代码中 self.fc = nn.Linear(16 * 64 * 64, num_classes) 这一行,以及 X = X.view(-1, 16 * 64 * 64) 展平操作,可能错误地估计了经过多次池化后的特征图尺寸。
展平操作不当: 使用 X.view(-1, C*H*W) 进行展平时,如果 C*H*W 计算错误,会导致展平后的张量形状与全连接层期望的输入不符。更稳健的做法是使用 X.view(X.size(0), -1),让PyTorch自动计算除批次大小外的其他维度,从而避免手动计算错误。
损失函数目标张量形状: nn.CrossEntropyLoss 期望的输入是模型输出的原始对数几率(logits)张量 (N, C) 和目标标签的类别索引张量 (N),其中 N 是批次大小,C 是类别数量。原始代码中使用 labels.squeeze().long() 可能会在某些情况下导致标签张量形状不正确,尤其当 labels 本身已经是 (N) 形状时,squeeze() 可能没有效果或产生意外结果。直接使用 labels.long() 通常更安全。
验证循环指标计算错误: 在验证阶段,correct_val 和 total_val 这两个变量没有在验证循环内部正确更新,导致验证准确率始终为零或出现除以零的错误。
针对上述问题,我们将对模型架构、损失函数计算和训练/验证循环进行以下修正。
核心在于修正 ConvNet 类中全连接层的输入尺寸和展平操作。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class ConvNet(nn.Module):
def __init__(self, num_classes=4):
super(ConvNet, self).__init__()
# 卷积层
self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
# 最大池化层
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接层:修正输入尺寸为 16 * 32 * 32
self.fc = nn.Linear(16 * 32 * 32, num_classes)
def forward(self, X):
# 卷积层、ReLU激活和最大池化
X = F.relu(self.conv1(X))
X = self.pool(X)
X = F.relu(self.conv2(X))
X = self.pool(X)
X = F.relu(self.conv3(X))
X = self.pool(X)
# 展平输出,保持批次大小不变,让PyTorch自动计算其他维度
X = X.view(X.size(0), -1)
# 全连接层
X = self.fc(X)
return X关键改动点:
在计算损失时,确保标签张量的形状符合 nn.CrossEntropyLoss 的要求。
# 训练循环中
# ...
# Forward pass
outputs = model(images)
# 直接使用 labels.long(),确保标签是长整型
loss = criterion(outputs, labels.long())
# ...
# 验证循环中
# ...
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
# 直接使用 labels.long()
loss = criterion(outputs, labels.long())
total_val_loss += loss.item()
# ...关键改动点:
确保在验证阶段正确地更新 correct_val 和 total_val,以便准确计算验证准确率。
# ... (其他代码保持不变,如 SceneDataset, get_dataloaders 等)
# 初始化你的网络
model = ConvNet()
# 定义你的损失函数
criterion = nn.CrossEntropyLoss()
# 初始化优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=5e-04)
# Placeholder for best validation accuracy
best_val_accuracy = 0.0
# Placeholder for the best model state
best_model_state = None
# Placeholder for training and validation statistics
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
# 开始训练
for epoch in range(max_epoch):
model.train() # 设置模型为训练模式
total_train_loss = 0.0
correct_train = 0
total_train = 0
for images, labels in train_loader:
optimizer.zero_grad()
# 前向传播
outputs = model(images)
# 计算损失
loss = criterion(outputs, labels.long())
# 反向传播和优化
loss.backward()
optimizer.step()
total_train_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total_train += labels.size(0)
correct_train += (predicted == labels).sum().item() # 修正:直接比较 predicted 和 labels
# 计算训练准确率和损失
train_accuracy = correct_train / total_train
train_losses.append(total_train_loss / len(train_loader))
train_accuracies.append(train_accuracy)
# 验证
model.eval() # 设置模型为评估模式
total_val_loss = 0.0
correct_val = 0 # 在每个epoch开始时重置
total_val = 0 # 在每个epoch开始时重置
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
loss = criterion(outputs, labels.long())
total_val_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total_val += labels.size(0) # 修正:更新 total_val
correct_val += (predicted == labels).sum().item() # 修正:更新 correct_val
# 计算验证准确率和损失
val_accuracy = correct_val / total_val if total_val > 0 else 0.0 # 避免除以零
val_losses.append(total_val_loss / len(val_loader))
val_accuracies.append(val_accuracy)
print(f"Epoch {epoch+1}/{max_epoch}, "
f"Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, "
f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accuracies[-1]:.4f}")
# 根据验证准确率保存最佳模型
if val_accuracy > best_val_accuracy:
best_val_accuracy = val_accuracy
best_model_state = model.state_dict()
# 保存最佳模型状态到文件
best_model_path = "best_cnn_sgd.pth"
if best_model_state:
torch.save(best_model_state, best_model_path)
print(f"Best model saved to {best_model_path} with validation accuracy: {best_val_accuracy:.4f}")
else:
print("No best model saved (validation accuracy did not improve).")
# 绘制损失图
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss vs. Epoch')
plt.legend()
plt.show()
# 绘制准确率图
plt.figure(figsize=(10, 5))
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy vs. Epoch')
plt.legend()
plt.show()关键改动点:
解决PyTorch CNN训练中的维度不匹配问题,特别是与全连接层输入尺寸、展平操作和损失函数目标形状相关的错误,是模型开发中的常见挑战。通过精确计算特征图尺寸、采用健壮的展平方法、确保损失函数输入正确,并细致地管理训练和验证循环中的指标,可以有效避免这些错误,从而构建稳定且高效的深度学习模型。本文提供的修正和建议旨在帮助开发者更好地理解和解决这些问题,为PyTorch模型的成功训练奠定基础。
以上就是PyTorch CNN训练中批次大小不匹配与维度错误:诊断与解决方案的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号