
本文深入探讨了 pytorch 中 `add_()` 等原地操作在广播机制下引发 `runtimeerror` 的原因。核心在于原地操作试图直接修改原始张量,而当广播结果的形状与原始张量形状不匹配时,无法在现有内存空间中完成操作。文章通过对比 numpy 和 pytorch 的行为,并提供正确的使用示例,帮助读者理解并避免此类常见错误。
PyTorch 的广播(Broadcasting)机制允许不同形状的张量在某些条件下执行算术运算,而无需显式地复制数据以使它们具有相同的形状。其基本规则如下:
例如,一个形状为 (1, 3, 1) 的张量与一个形状为 (3, 1, 7) 的张量进行加法运算时,根据广播规则,它们可以扩展为 (3, 3, 7) 的形状进行逐元素操作。
在 PyTorch 中,许多操作都提供两种形式:
理解这两种操作的区别是解决本文所讨论问题的关键。
当尝试执行以下 PyTorch 代码时,会遇到 RuntimeError:
import torch x = torch.empty(1, 3, 1) y = torch.empty(3, 1, 7) # 尝试原地加法操作 (x.add_(y)).size()
报错信息如下: RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]
这个错误清楚地指出了问题所在:x 的原始形状 [1, 3, 1] 与广播后的预期形状 [3, 3, 7] 不匹配。
根本原因在于:
PyTorch 官方文档对此有明确说明:“对于原地操作,原地张量必须能够与另一个张量进行广播,并且原地张量的存储必须足够大以存储结果。如果原地张量的存储不够大,则会引发错误。”
为了更好地理解 PyTorch 的行为,我们可以对比 NumPy 的相同操作:
import numpy as np x_np = np.empty((1, 3, 1)) y_np = np.empty((3, 1, 7)) # NumPy 的加法操作 (x_np + y_np).shape # Output: (3, 3, 7)
NumPy 能够正确执行并生成 (3, 3, 7) 形状的结果。这是因为 NumPy 的 + 运算符默认执行的是非原地操作。它会创建一个全新的数组来存储 x_np 和 y_np 广播后的结果,而不是尝试修改 x_np 的原始内存。因此,NumPy 不会遇到 PyTorch 中因原地修改导致的内存/形状不匹配问题。
要避免 PyTorch 中 add_() 等原地操作在广播时引发的 RuntimeError,应采取以下策略:
优先使用非原地操作: 这是最推荐的做法。当需要进行广播运算时,使用非原地操作(如 + 运算符或 torch.add() 函数),它们会返回一个新的张量,而不会尝试修改原始张量。
import torch x = torch.empty(1, 3, 1) y = torch.empty(3, 1, 7) # 使用非原地操作 result = x + y print(result.size()) # Output: torch.Size([3, 3, 7]) # 或者使用 torch.add() result_add = torch.add(x, y) print(result_add.size()) # Output: torch.Size([3, 3, 7])
理解原地操作的适用场景: 原地操作通常用于:
PyTorch 的原地操作(以 _ 结尾的函数)提供了内存优化的可能性,但它们也引入了额外的限制。在进行涉及广播的运算时,务必注意以下几点:
理解 PyTorch 中原地操作与非原地操作之间的细微差别,以及它们与广播机制的交互方式,对于编写健壮且高效的 PyTorch 代码至关重要。
以上就是PyTorch 广播机制下的原地操作陷阱:add_() 形状不匹配错误深度解析的详细内容,更多请关注php中文网其它相关文章!
 
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
 
                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号