解决PyTorch模型中torch.cat操作的张量尺寸不匹配问题

心靈之曲
发布: 2025-11-14 13:47:20
原创
967人浏览过

解决PyTorch模型中torch.cat操作的张量尺寸不匹配问题

本教程旨在解决pytorch模型中常见的`runtimeerror: sizes of tensors must match except in dimension 1`错误,该错误通常发生在编码器-解码器架构(如hourglass网络)的`torch.cat`操作中。文章将详细分析导致空间维度和通道维度不匹配的原因,并提供一套系统的解决方案,包括调整输入图像尺寸、重新校准反卷积层输出通道,并提供一个修正后的模型架构示例,以确保张量操作的兼容性。

在PyTorch深度学习模型的开发过程中,尤其是在构建具有跳跃连接(skip connections)的编码器-解码器架构(如U-Net或Hourglass网络)时,经常会遇到张量尺寸不匹配的问题。其中一个常见的错误是RuntimeError: Sizes of tensors must match except in dimension 1,它通常发生在尝试使用torch.cat函数拼接不同空间尺寸的张量时。本教程将深入探讨这一问题,并提供详细的诊断与解决方案。

理解Hourglass网络与张量尺寸变化

Hourglass网络是一种经典的编码器-解码器结构,通过一系列的卷积(下采样)和反卷积(上采样)层来提取和恢复特征。其关键特点是跳跃连接,它将编码器路径中的特征图直接连接到解码器路径中对应的上采样特征图,以保留空间信息。

在nn.Conv2d(卷积)和nn.ConvTranspose2d(反卷积/转置卷积)操作中,张量的空间维度(高度和宽度)会根据以下公式发生变化:

  • nn.Conv2d (下采样): H_out = floor((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0]) + 1W_out = floor((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1]) + 1
  • nn.ConvTranspose2d (上采样): H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1

这些公式表明,输入张量的尺寸、卷积核大小、步长、填充和输出填充都会影响输出张量的空间维度。在Hourglass网络中,多层下采样后,再进行多层上采样,如果输入尺寸选择不当,或者各层参数配置不精确,很容易导致在跳跃连接处,上采样路径的张量与跳跃连接的张量空间维度不一致。

诊断torch.cat张量尺寸不匹配错误

当出现RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 18 but got size 17 for tensor number 1 in the list.这样的错误时,意味着在尝试拼接两个张量时,除了通道维度(dimension 1)外,其他维度(通常是高度和宽度)不一致。

例如,在提供的代码中,错误发生在以下行:

up_5 = torch.cat([up_5, skip_5], 1)
登录后复制

根据错误信息,up_5和skip_5在高度或宽度上存在不匹配。假设up_5的尺寸为 [1, 64, 18, 12],而skip_5的尺寸为 [1, 4, 17, 11]。虽然通道数不同是允许的(torch.cat沿着指定维度拼接),但空间维度 (18, 12) 与 (17, 11) 的不一致导致了错误。

此外,即使空间维度匹配,还可能遇到第二个问题:通道维度不匹配。在torch.cat操作后,新的张量通道数是两个输入张量通道数之和。如果紧随其后的nn.BatchNorm2d层期望的输入通道数与实际拼接后的通道数不符,则会引发RuntimeError: running_mean should contain X elements not Y之类的错误。例如,如果up_5输出64通道,skip_5输出4通道,拼接后是68通道。但如果self.u_bn_5 = nn.BatchNorm2d(64),那么BatchNorm2d将期望64通道,从而导致错误。

解决方案

解决这类问题需要从两个主要方面入手:调整输入数据尺寸和修正模型架构中的通道配置。

商汤商量
商汤商量

商汤科技研发的AI对话工具,商量商量,都能解决。

商汤商量 36
查看详情 商汤商量

1. 调整输入图像尺寸以适应网络结构

Hourglass网络通常涉及多级下采样,例如6个nn.Conv2d层,每个层将空间维度减半。这意味着输入图像的尺寸应该能够被 2^6 = 64 整除,以避免在下采样过程中出现小数或不规则的尺寸变化,从而导致上采样时无法精确恢复到原始或匹配的尺寸。

根据提供的模型架构,为了使下采样和上采样路径的空间维度能够正确对齐,建议将输入张量 z 的尺寸调整为 [1, 2, 512, 320]。这里的512和320都是64的倍数,能够更好地适应模型的下采样操作。

# 原始代码
# z = torch.Tensor(np.mgrid[:542, :347]).unsqueeze(0).to(device) / 512
# x = PILImage.open(original_image_path)
# x = to_tensor(x).unsqueeze(0)
# x, mask = pixel_thanos(x, 0.5) # mask size is also [1, 3, 542, 347]

# 修正后的输入尺寸示例
# 确保原始图像和mask的尺寸也与模型兼容
input_height, input_width = 512, 320 # 假设这是你的目标输入尺寸
z = torch.Tensor(np.mgrid[:input_height, :input_width]).unsqueeze(0).to(device) / max(input_height, input_width)

# 如果原始图像尺寸不符,需要进行预处理
x = PILImage.open(original_image_path).resize((input_width, input_height)) # 调整图像尺寸
x = to_tensor(x).unsqueeze(0)
x, mask = pixel_thanos(x, 0.5, target_height=input_height, target_width=input_width) # 确保mask尺寸也匹配
mask = mask[:, :3, :, :].to(device).float()
x = x.to(device)

# 修改 pixel_thanos 函数以支持指定尺寸
def pixel_thanos(img, p=0.5, target_height=542, target_width=347):
    assert p > 0 and p < 1, 'The probability value should lie in (0, 1)'
    mask = torch.rand(1, 3, target_height, target_width) # 使用目标尺寸
    img[mask < p,] = 0
    mask = mask > p
    mask = mask.repeat(1, 3, 1, 1)
    return img, mask
登录后复制

2. 重新校准解码器通道维度

在Hourglass网络的解码器部分,nn.ConvTranspose2d层的输出通道数需要与跳跃连接的通道数以及后续的nn.BatchNorm2d层期望的通道数保持一致。具体来说,当执行torch.cat([up_tensor, skip_tensor], 1)时,up_tensor的通道数加上skip_tensor的通道数,必须等于紧随其后的nn.BatchNorm2d层所期望的输入通道数。

例如,如果skip_5有4个通道,而nn.BatchNorm2d(128)期望128个通道,那么u_deconv_5的输出通道数就应该是 128 - 4 = 124。

以下是根据上述原则修正后的Hourglass模型架构:

import torch
import torch.nn as nn
import numpy as np

class Hourglass(nn.Module):
    def __init__(self):
        super(Hourglass, self).__init__()

        self.leaky_relu = nn.LeakyReLU()

        # Encoder Path (Downsampling)
        self.d_conv_1 = nn.Conv2d(2, 8, 5, stride=2, padding=2)
        self.d_bn_1 = nn.BatchNorm2d(8)

        self.d_conv_2 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
        self.d_bn_2 = nn.BatchNorm2d(16)

        self.d_conv_3 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.d_bn_3 = nn.BatchNorm2d(32)
        self.s_conv_3 = nn.Conv2d(32, 4, 5, stride=1, padding=2) # Skip connection 3

        self.d_conv_4 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.d_bn_4 = nn.BatchNorm2d(64)
        self.s_conv_4 = nn.Conv2d(64, 4, 5, stride=1, padding=2) # Skip connection 4

        self.d_conv_5 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.d_bn_5 = nn.BatchNorm2d(128)
        self.s_conv_5 = nn.Conv2d(128, 4, 5, stride=1, padding=2) # Skip connection 5

        self.d_conv_6 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
        self.d_bn_6 = nn.BatchNorm2d(256)

        # Decoder Path (Upsampling) - Adjusted output channels for concatenation
        # u_deconv_5 output + skip_5 channels = u_bn_5 input
        # 124 + 4 = 128
        self.u_deconv_5 = nn.ConvTranspose2d(256, 124, 4, stride=2, padding=1)
        self.u_bn_5 = nn.BatchNorm2d(128)

        # u_deconv_4 output + skip_4 channels = u_bn_4 input
        # 60 + 4 = 64
        self.u_deconv_4 = nn.ConvTranspose2d(128, 60, 4, stride=2, padding=1) # Input channels should match previous layer's output
        self.u_bn_4 = nn.BatchNorm2d(64)

        # u_deconv_3 output + skip_3 channels = u_bn_3 input
        # 28 + 4 = 32
        self.u_deconv_3 = nn.ConvTranspose2d(64, 28, 4, stride=2, padding=1)
        self.u_bn_3 = nn.BatchNorm2d(32)

        # Remaining upsampling layers (no skip connections here, just match previous output to next input)
        self.u_deconv_2 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)
        self.u_bn_2 = nn.BatchNorm2d(16)

        self.u_deconv_1 = nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1)
        self.u_bn_1 = nn.BatchNorm2d(8)

        self.out_deconv = nn.ConvTranspose2d(8, 4, 4, stride=2, padding=1)
        self.out_bn = nn.BatchNorm2d(4)

    def forward(self, noise):
        # Print tensor sizes for debugging
        # print("input:", noise.size())

        down_1 = self.d_conv_1(noise)
        down_1 = self.d_bn_1(down_1)
        down_1 = self.leaky_relu(down_1)
        # print("down_1:", down_1.size())

        down_2 = self.d_conv_2(down_1)
        down_2 = self.d_bn_2(down_2)
        down_2 = self.leaky_relu(down_2)
        # print("down_2:", down_2.size())

        down_3 = self.d_conv_3(down_2)
        down_3 = self.d_bn_3(down_3)
        down_3 = self.leaky_relu(down_3)
        skip_3 = self.s_conv_3(down_3)
        # print("skip_3:", skip_3.size())

        down_4 = self.d_conv_4(down_3)
        down_4 = self.d_bn_4(down_4)
        down_4 = self.leaky_relu(down_4)
        skip_4 = self.s_conv_4(down_4)
        # print("skip_4:", skip_4.size())

        down_5 = self.d_conv_5(down_4)
        down_5 = self.d_bn_5(down_5)
        down_5 = self.leaky_relu(down_5)
        skip_5 = self.s_conv_5(down_5)
        # print("skip_5:", skip_5.size())

        down_6 = self.d_conv_6(down_5)
        down_6 = self.d_bn_6(down_6)
        down_6 = self.leaky_relu(down_6)
        # print("down_6:", down_6.size())

        up_5 = self.u_deconv_5(down_6)
        up_5 = torch.cat([up_5, skip_5], 1) # Spatial dimensions must match here
        up_5 = self.u_bn_5(up_5)
        up_5 = self.leaky_relu(up_5)
        # print("up_5:", up_5.size())

        up_4 = self.u_deconv_4(up_5)
        up_4 = torch.cat([up_4, skip_4], 1)
        up_4 = self.u_bn_4(up_4)
        up_4 = self.leaky_relu(up_4)
        # print("up_4", up_4.size())

        up_3 = self.u_deconv_3(up_4)
        up_3 = torch.cat([up_3, skip_3], 1)
        up_3 = self.u_bn_3(up_3)
        up_3 = self.leaky_relu(up_3)
        # print("up_3", up_3.size())

        up_2 = self.u_deconv_2(up_3)
        up_2 = self.u_bn_2(up_2)
        up_2 = self.leaky_relu(up_2)
        # print("up_2", up_2.size())

        up_1 = self.u_deconv_1(up_2)
        up_1 = self.u_bn_1(up_1)
        up_1 = self.leaky_relu(up_1)
        # print("up_1", up_1.size())

        out = self.out_deconv(up_1)
        out = self.out_bn(out)
        out = nn.Sigmoid()(out)
        # print("out", out.size())
        return out

# 示例使用修正后的模型
model = Hourglass()
# 使用建议的输入尺寸
x_input = torch.Tensor(np.mgrid[:512, :320]).unsqueeze(0) # 模拟输入噪音,通道数为2
y_output = model(x_input)
print("Final output size:", y_output.size())
登录后复制

注意事项:

  • 调试技巧: 在模型的forward方法中,像上面注释掉的print(tensor.size())语句一样,打印每个关键张量的尺寸是诊断这类问题的最有效方法。它能让你清晰地看到每一步操作后张量的维度变化,从而快速定位不匹配发生的位置。
  • output_padding参数: nn.ConvTranspose2d有一个output_padding参数,它可以用于微调输出尺寸,以解决1个像素的尺寸偏差。在某些情况下,当output_size = (input_size - 1) * stride - 2 * padding + kernel_size计算出的尺寸比目标尺寸小1时,将output_padding设为1可以解决问题。
  • 图像预处理: 在将实际图像输入模型之前,务必确保它们被正确地缩放或裁剪到与模型期望的输入尺寸完全一致。如果原始图像尺寸不规则,可能需要进行图像插值或填充操作。
  • 通道数一致性: 始终检查ConvTranspose2d的out_channels,确保它与skip连接的通道数之和能够匹配后续BatchNorm2d或下一个卷积层的in_channels。

总结

RuntimeError: Sizes of tensors must match except in dimension 1是PyTorch中常见的维度不匹配错误,尤其在复杂网络架构中。解决此问题需要系统性地检查模型架构,特别是卷积和反卷积层的参数配置,以及输入数据的尺寸。通过确保输入尺寸与网络下采样能力兼容,并精确调整解码器中ConvTranspose2d层的输出通道以匹配跳跃连接和后续层的期望,可以有效地消除这类张量

以上就是解决PyTorch模型中torch.cat操作的张量尺寸不匹配问题的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号