在深度学习领域,有时我们需要一个模型同时完成多个相关的分类任务,例如,给定一幅图像,同时预测其艺术家、流派和风格。这被称为多任务分类。构建此类模型时,通常有两种策略:
在实践中,直接修改预训练模型(如ResNet18)的分类器可能不如预期。例如,简单地为ResNetForImageClassification实例添加classifier_artist、classifier_style、classifier_genre等属性,并不能自动将其集成到模型的forward方法中。torchinfo的输出也印证了这一点,模型的主体仍然是其原有的ResNetModel和Sequential (classifier),并未包含新定义的分类器。这通常意味着需要继承并重写模型的forward方法,或者正确地替换原有的分类头。
当自定义PyTorch模型时,我们拥有更大的灵活性来设计多任务架构。然而,这也引入了新的挑战,尤其是在处理不同层之间的数据维度匹配问题上。
构建自定义的WikiartModel用于多任务分类时,我们定义了共享的卷积层用于特征提取,并为艺术家、流派和风格三个任务分别设置了独立的全连接层分支。模型定义如下:
import torch import torch.nn as nn import torch.nn.functional as F class WikiartModel(nn.Module): def __init__(self, num_artists, num_genres, num_styles): super(WikiartModel, self).__init__() # Shared Convolutional Layers self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) # Artist classification branch (Incorrect input size) self.fc_artist1 = nn.Linear(256 * 16 * 16, 512) # Potentially incorrect self.fc_artist2 = nn.Linear(512, num_artists) # Genre classification branch (Incorrect input size) self.fc_genre1 = nn.Linear(256 * 16 * 16, 512) # Potentially incorrect self.fc_genre2 = nn.Linear(512, num_genres) # Style classification branch (Incorrect input size) self.fc_style1 = nn.Linear(256 * 16 * 16, 512) # Potentially incorrect self.fc_style2 = nn.Linear(512, num_styles) def forward(self, x): # Shared convolutional layers x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = self.pool(F.relu(self.conv3(x))) x = x.view(-1, 256 * 16 * 16) # Potentially incorrect flattening # Artist classification branch artists_out = F.relu(self.fc_artist1(x)) artists_out = self.fc_artist2(artists_out) # Genre classification branch genre_out = F.relu(self.fc_genre1(x)) genre_out = self.fc_genre2(genre_out) # Style classification branch style_out = F.relu(self.fc_style1(x)) style_out = self.fc_style2(style_out) return artists_out, genre_out, style_out # num_artists, num_genres, num_styles are defined externally
在使用torchinfo检查模型结构时,我们发现一个关键问题:模型的输入批次大小为32(例如[32, 3, 224, 224]),但其内部全连接层(如fc_artist1)的输入批次大小却变成了98,导致最终输出的批次大小也为98。这直接引发了训练循环中计算损失时的ValueError: Expected input batch_size (98) to match target batch_size (32).错误。
问题根源分析:
这个批次大小不一致的根本原因在于卷积层输出特征图的尺寸计算错误,以及随后对特征图进行展平(flatten)操作时,全连接层期望的输入维度与实际不符。
让我们逐步分析数据流:
因此,在进入全连接层之前,特征图的实际尺寸是 [32, 256, 28, 28]。
问题出在这一行:x = x.view(-1, 256 * 16 * 16)。 当x的实际形状是[32, 256, 28, 28]时,总元素数量为 32 * 256 * 28 * 28 = 6422528。 而256 * 16 * 16 = 65536。 当使用x.view(-1, 65536)时,PyTorch会尝试将总元素数量除以65536来推断-1对应的维度: 6422528 / 65536 = 98。 所以,x被错误地展平为了[98, 65536],导致批次大小从32变成了98。
要解决这个问题,我们需要确保全连接层的输入维度与卷积层输出的实际展平尺寸相匹配,并且批次大小在展平过程中保持不变。
如上分析,经过三次卷积和三次最大池化操作后,对于 224x224 的输入图像,最终的特征图尺寸是 [Batch_Size, 256, 28, 28]。因此,展平后的特征向量长度应该是 256 * 28 * 28。
在将卷积层的输出传递给全连接层之前,需要将其展平为二维张量 [Batch_Size, Features]。为了确保批次大小不变,应该使用 x.view(x.size(0), -1)。这里的 x.size(0) 会保留原始的批次大小(例如32),而 -1 会自动计算剩余维度的乘积,将其展平为单个特征向量。
对于 [32, 256, 28, 28] 的张量,x.view(x.size(0), -1) 会将其展平为 [32, 256 * 28 * 28],即 [32, 200704]。
基于正确的展平尺寸,所有连接到卷积层输出的全连接层(fc_artist1, fc_genre1, fc_style1)的 in_features 参数都应该修改为 256 * 28 * 28。
# 将 nn.Linear(256 * 16 * 16, 512) # 修正为 nn.Linear(256 * 28 * 28, 512) # 256 * 28 * 28 = 200704
根据上述修正,WikiartModel的定义应更新如下:
import torch import torch.nn as nn import torch.nn.functional as F class WikiartModel(nn.Module): def __init__(self, num_artists, num_genres, num_styles): super(WikiartModel, self).__init__() # Shared Convolutional Layers self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) # 计算卷积层最终输出的特征图尺寸,用于全连接层 # 对于224x224输入,经过三次conv+pool后,尺寸变为 28x28 self.final_feature_map_size = 28 self.flattened_features = 256 * self.final_feature_map_size * self.final_feature_map_size # 256 * 28 * 28 = 200704 # Artist classification branch self.fc_artist1 = nn.Linear(self.flattened_features, 512) self.fc_artist2 = nn.Linear(512, num_artists) # Genre classification branch self.fc_genre1 = nn.Linear(self.flattened_features, 512) self.fc_genre2 = nn.Linear(512, num_genres) # Style classification branch self.fc_style1 = nn.Linear(self.flattened_features, 512) self.fc_style2 = nn.Linear(512, num_styles) def forward(self, x): # Shared convolutional layers x = self.pool(F.relu(self.conv1(x))) # Output: [Batch_Size, 64, 112, 112] x = self.pool(F.relu(self.conv2(x))) # Output: [Batch_Size, 128, 56, 56] x = self.pool(F.relu(self.conv3(x))) # Output: [Batch_Size, 256, 28, 28] # Correct flattening: preserve batch size, flatten remaining dimensions x = x.view(x.size(0), -1) # Output: [Batch_Size, 256 * 28 * 28] = [Batch_Size, 200704] # Artist classification branch artists_out = F.relu(self.fc_artist1(x)) artists_out = self.fc_artist2(artists_out) # Genre classification branch genre_out = F.relu(self.fc_genre1(x)) genre_out = self.fc_genre2(genre_out) # Style classification branch style_out = F.relu(self.fc_style1(x)) style_out = self.fc_style2(style_out) return artists_out, genre_out, style_out # Example usage: num_artists = 129 num_genres = 11 num_styles = 27 model = WikiartModel(num_artists, num_genres, num_styles) # Now, if you pass a tensor of shape [32, 3, 224, 224] to the model, # the outputs will correctly have a batch size of 32. # e.g., artists_out.shape will be [32, 129]
批次大小不一致是PyTorch模型开发中常见的错误,尤其是在卷积层和全连接层之间进行维度转换时。解决此问题的关键在于:
通过遵循这些原则,可以有效地避免和解决PyTorch模型中因维度不匹配导致的批次大小不一致问题,确保模型能够顺利训练。
以上就是解决PyTorch多任务模型中批次大小不一致问题:卷积层输出展平与全连接层连接的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号