
torch.vmap 是 PyTorch 提供的一个强大工具,它允许我们将一个处理单个样本的函数(即非批处理函数)转换为一个能够高效处理一批样本的函数,而无需手动管理批处理维度。这在编写通用代码和加速计算方面非常有用。然而,当被 vmap 向量化的函数内部需要创建新的张量,并且这些张量的形状依赖于批处理输入的形状时,就会遇到一个常见的陷阱。
考虑以下场景:我们有一个函数 polycompanion,它接收一个多项式系数张量,并计算其伴随矩阵。伴随矩阵的维度取决于多项式的次数。
import torch
poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)
def polycompanion(polynomial):
# polynomial.shape[-1] 是多项式系数的个数,例如 [a, b, c, d] 代表 ax^3 + bx^2 + cx + d
# 次数 deg = 系数个数 - 1 - 1 = 系数个数 - 2 (如果最后一个系数是常数项)
deg = polynomial.shape[-1] - 2
# 尝试创建伴随矩阵
companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
# 填充单位矩阵部分
companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)
# 填充最后一列
# 注意这里 polynomial[:-1] 表示除了最后一个系数以外的所有系数
# polynomial[-1] 表示最后一个系数
companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
return companion
# 尝试使用 vmap 向量化
polycompanion_vmap = torch.vmap(polycompanion)
try:
print(polycompanion_vmap(poly_batched))
except Exception as e:
print(f"Initial attempt failed: {e}")上述代码在执行 polycompanion_vmap(poly_batched) 时会失败。原因是 polycompanion 函数内部通过 torch.zeros((deg+1, deg+1)) 创建了一个新的 companion 张量。尽管 deg 是从 polynomial(一个批处理输入)派生出来的,但 torch.zeros 本身创建的是一个普通的、非批处理的张量。当 vmap 试图对这个非批处理的 companion 张量执行批处理操作(例如,将其与从 polynomial 派生的批处理张量进行索引或赋值)时,就会出现维度不匹配或类型不兼容的问题,因为 vmap 期望所有参与运算的张量都带有批处理维度。
torch.vmap 的核心机制是跟踪批处理维度,并将操作提升到批处理层面。它能识别作为 vmap 输入的张量及其通过各种张量操作(如加法、乘法、切片等)派生出的张量,并为它们自动添加和管理批处理维度。然而,像 torch.zeros 这种从零开始创建新张量的操作,其默认行为是创建一个标准张量,不包含任何批处理维度信息。即使其形状参数 (deg+1, deg+1) 是基于批处理输入计算得出的,torch.zeros 也无法“感知”到外部的 vmap 上下文,从而无法自动生成一个 BatchedTensor。
torch.zeros_like 是一个例外,因为它基于一个已存在的张量来创建新张量。如果这个已存在的张量是 BatchedTensor,那么 torch.zeros_like 也能创建出一个 BatchedTensor。但在本例中,我们没有一个现成的 BatchedTensor 可以作为 zeros_like 的模板来创建 companion。
一种可行的(但不理想的)规避方法是,在调用 vmap 之前,手动创建一个带有批处理维度的 companion 张量,并将其作为函数的额外输入传递给 vmap。
def polycompanion_workaround(polynomial, companion_template):
# 注意:这里的 deg 现在从 companion_template 的形状推断,因为它已经有了批处理维度
deg = companion_template.shape[-1] - 1
# 在传入的 companion_template 上进行就地修改
companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)
companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
return companion_template
polycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)
# 预先创建批处理的 companion 模板
# poly_batched.shape[0] 是批次大小
# poly_batched.shape[-1]-1 是伴随矩阵的行/列维度
companion_init_shape = (poly_batched.shape[0], poly_batched.shape[-1] - 1, poly_batched.shape[-1] - 1)
pre_batched_companion = torch.zeros(companion_init_shape, dtype=torch.float32)
print("--- Workaround Output ---")
print(polycompanion_vmap_workaround(poly_batched, pre_batched_companion))这种方法虽然能够正确输出结果,但存在明显缺点:
为了在 vmap 上下文中优雅地创建和填充张量,我们可以避免在非批处理的 torch.zeros 张量上进行就地修改。相反,我们将伴随矩阵视为由两部分组成:一个包含单位矩阵的左侧部分,以及一个由多项式系数计算得出的右侧(最后一列)部分。然后,我们分别构建这两部分,并使用 torch.concatenate 将它们合并。
关键在于:
以下是改进后的 polycompanion 函数:
def polycompanion_optimized(polynomial):
deg = polynomial.shape[-1] - 2
# 1. 创建一个基础的非批处理张量来填充单位矩阵部分
# 这是一个临时的、非批处理的张量
base_matrix = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
base_matrix[1:, :-1] = torch.eye(deg, dtype=torch.float32)
# 2. 提取 base_matrix 的左侧部分,并进行克隆
# clone() 创建了一个新的张量,虽然它仍然是非批处理的,
# 但在 vmap 上下文中,当它与批处理张量拼接时,vmap 会正确处理
left_part = base_matrix[:, :-1].clone()
# 3. 计算伴随矩阵的最后一列
# 这一部分完全从批处理输入 polynomial 派生,因此 vmap 会将其视为批处理张量
# polynomial[:-1] 是 (deg+1,) 形状
# polynomial[-1] 是标量
# 结果是一个 (deg+1,) 形状的张量
last_column_values = -1. * polynomial[:-1] / polynomial[-1]
# 4. 扩展最后一列的维度,使其可以与 left_part 进行拼接
# last_column_values 是 (deg+1,),我们需要将其变为 (deg+1, 1)
last_column_reshaped = last_column_values[:, None]
# 5. 使用 concatenate 组合左右两部分
# vmap 会识别 left_part 和 last_column_reshaped,并为它们在批次维度上执行拼接
final_companion = torch.concatenate([left_part, last_column_reshaped], dim=1)
return final_companion
polycompanion_vmap_optimized = torch.vmap(polycompanion_optimized)
print("\n--- Optimized Solution Output ---")
print(polycompanion_vmap_optimized(poly_batched))输出:
tensor([[[ 0.0000, 0.0000, -0.2500],
[ 1.0000, 0.0000, -0.5000],
[ 0.0000, 1.0000, -0.7500]],
[[ 0.0000, 0.0000, -0.2500],
[ 1.0000, 0.0000, -0.5000],
[ 0.0000, 1.0000, -0.7500]]])这个解决方案成功地生成了批处理的伴随矩阵,同时保持了 polycompanion_optimized 函数的简洁性,使其能够独立处理单个样本,并且不需要外部预分配张量。
在 torch.vmap 中处理函数内部的张量创建是一个常见的挑战。通过理解 vmap 对批处理张量的期望,并采用 clone() 结合 torch.concatenate 的策略,我们能够优雅地构建出所需的批处理张量,而无需妥协函数的简洁性或引入复杂的外部依赖。这种方法体现了在 PyTorch 中进行高效张量操作的灵活性和强大功能,是掌握 torch.vmap 的一个重要技巧。
以上就是在 torch.vmap 中高效处理内部张量创建的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号