在pytorch中进行多标签图像分类时,我们可能需要构建自定义模型来同时预测多个属性(例如,艺术家的作品、流派和风格)。一个常见的问题是,模型的输入批量大小与输出批量大小不匹配,这通常在计算损失时导致valueerror: expected input batch_size (98) to match target batch_size (32).。
例如,当我们期望输入图像批次为 [32, 3, 224, 224](批量大小为32),但模型输出的预测结果却显示为 [98, N_classes](批量大小为98),这明显表明在模型内部的某个环节,批量维度发生了意外的改变。通过torchinfo工具查看模型摘要,可以清晰地看到这种不一致:
Layer (type (var_name)) Input Shape Output Shape ================================================================================ WikiartModel (WikiartModel) [32, 3, 224, 224] [98, 129] ├─Conv2d (conv1) [32, 3, 224, 224] [32, 64, 224, 224] ├─MaxPool2d (pool) [32, 64, 224, 224] [32, 64, 112, 112] ├─Conv2d (conv2) [32, 64, 112, 112] [32, 128, 112, 112] ├─MaxPool2d (pool) [32, 128, 112, 112] [32, 128, 56, 56] ├─Conv2d (conv3) [32, 128, 56, 56] [32, 256, 56, 56] ├─MaxPool2d (pool) [32, 256, 56, 56] [32, 256, 28, 28] ├─Linear (fc_artist1) [98, 65536] [98, 512] ...
从上述摘要中可以看出,在经过一系列卷积和池化层后,张量的批量大小仍然保持为32,但在进入第一个全连接层(fc_artist1)时,输入形状的批量大小突然变成了98,这正是问题的根源。
这种批量大小的意外变化,几乎总是由于在将卷积层的输出展平(flatten)为全连接层的输入时,torch.Tensor.view() 方法使用不当造成的。
让我们分析一下 WikiartModel 中的数据流:
输入图像: [32, 3, 224, 224] (批量大小,通道,高度,宽度)
通过卷积和池化层:
到此为止,批量大小(32)是正确的,图像的特征图尺寸为 256 x 28 x 28。
展平操作: 原始代码中使用的展平操作是:
x = x.view(-1, 256 * 16 * 16)
这里的问题在于,256 * 16 * 16 (65536) 是一个固定的、错误的展平维度。模型在经过卷积层后,其特征图的实际空间维度是 28x28,而不是 16x16。
当 view(-1, K) 被调用时,PyTorch会尝试将张量重塑为 (N, K) 的形状,其中 N 是通过保持总元素数量不变来计算的。
这就是导致批量大小从32意外变为98的根本原因。这种不正确的展平操作使得模型内部的批量大小与输入数据的批量大小不一致,从而在后续的损失计算中引发错误。
要解决此问题,我们需要进行两处关键修正:
修正 forward 方法中的展平操作: 为了保持原始的批量大小并展平剩余的维度,我们应该使用 x.view(x.size(0), -1)。x.size(0) 明确地保留了原始的批量大小,而 -1 则让PyTorch自动计算剩余维度展平后的总大小。或者,更清晰地,可以使用 torch.flatten(x, 1),它会从第一个维度(即批量维度之后)开始展平。
将:
x = x.view(-1, 256 * 16 * 16)
修改为:
x = x.view(x.size(0), -1) # 或者 x = torch.flatten(x, 1)
修正全连接层输入维度: 由于现在我们正确地展平了张量,全连接层的 in_features 参数必须与展平后的实际维度匹配。经过 [32, 256, 28, 28] 的张量展平后,每个样本的特征维度是 256 * 28 * 28。
将所有 nn.Linear(256 * 16 * 16, 512) 修改为:
nn.Linear(256 * 28 * 28, 512)
为了代码的可读性和维护性,可以在 __init__ 中计算这个尺寸并存储,例如 self.flatten_size = 256 * 28 * 28。
以下是修正后的 WikiartModel 示例代码:
import torch import torch.nn as nn import torch.nn.functional as F class WikiartModel(nn.Module): def __init__(self, num_artists, num_genres, num_styles): super(WikiartModel, self).__init__() # Shared Convolutional Layers self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) # 计算经过卷积和池化后特征图的最终空间维度 # 224 -> (pool) 112 -> (pool) 56 -> (pool) 28 self.final_spatial_dim = 28 self.flatten_features = 256 * self.final_spatial_dim * self.final_spatial_dim # 256 * 28 * 28 = 200704 # Artist classification branch self.fc_artist1 = nn.Linear(self.flatten_features, 512) self.fc_artist2 = nn.Linear(512, num_artists) # Genre classification branch self.fc_genre1 = nn.Linear(self.flatten_features, 512) self.fc_genre2 = nn.Linear(512, num_genres) # Style classification branch self.fc_style1 = nn.Linear(self.flatten_features, 512) self.fc_style2 = nn.Linear(512, num_styles) def forward(self, x): # Shared convolutional layers x = self.pool(F.relu(self.conv1(x))) # Output: [batch_size, 64, 112, 112] x = self.pool(F.relu(self.conv2(x))) # Output: [batch_size, 128, 56, 56] x = self.pool(F.relu(self.conv3(x))) # Output: [batch_size, 256, 28, 28] # Correct flattening: preserve batch size, flatten remaining dimensions x = x.view(x.size(0), -1) # Output: [batch_size, 256 * 28 * 28] = [batch_size, 200704] # Artist classification branch artists_out = F.relu(self.fc_artist1(x)) artists_out = self.fc_artist2(artists_out) # Genre classification branch genre_out = F.relu(self.fc_genre1(x)) genre_out = self.fc_genre2(genre_out) # Style classification branch style_out = F.relu(self.fc_style1(x)) style_out = self.fc_style2(style_out) return artists_out, genre_out, style_out # Set the number of classes for each task num_artists = 129 num_genres = 11 num_styles = 27 # Example usage (for demonstration) model = WikiartModel(num_artists, num_genres, num_styles) dummy_input = torch.randn(32, 3, 224, 224) # Batch size 32 artists_pred, genres_pred, styles_pred = model(dummy_input) print(f"Artist predictions shape: {artists_pred.shape}") # Expected: [32, 129] print(f"Genre predictions shape: {genres_pred.shape}") # Expected: [32, 11] print(f"Style predictions shape: {styles_pred.shape}") # Expected: [32, 27]
通过这些修正,模型的数据流将变得一致,并且批量大小将正确地从输入传递到输出,从而解决损失计算时的 ValueError。
# ... after conv layers self.flatten = nn.Flatten() # ... def forward(self, x): # ... conv layers x = self.flatten(x) # ...
批量大小不一致是PyTorch模型开发中一个常见的、但往往令人困惑的问题。它通常源于对 torch.Tensor.view() 等张量操作的误解,尤其是在将多维卷积输出展平为全连接层输入时。通过精确计算中间张量形状,并使用 x.view(x.size(0), -1) 或 torch.flatten(x, 1) 等正确方法进行展平,可以有效地避免此类问题。在模型开发过程中,持续利用 torchinfo 或手动打印形状进行调试,是确保模型数据流正确性和稳定性的关键。
以上就是PyTorch多标签图像分类:批量大小不一致问题的诊断与解决的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号