在 torch.vmap 中高效处理内部张量创建

碧海醫心
发布: 2025-10-20 14:47:07
原创
195人浏览过

在 torch.vmap 中高效处理内部张量创建

理解 torch.vmap 与内部张量创建的挑战

torch.vmap 是 PyTorch 提供的一个强大工具,它允许我们将一个处理单个样本的函数(即非批处理函数)转换为一个能够高效处理一批样本的函数,而无需手动管理批处理维度。这在编写通用代码和加速计算方面非常有用。然而,当被 vmap 向量化的函数内部需要创建新的张量,并且这些张量的形状依赖于批处理输入的形状时,就会遇到一个常见的陷阱。

考虑以下场景:我们有一个函数 polycompanion,它接收一个多项式系数张量,并计算其伴随矩阵。伴随矩阵的维度取决于多项式的次数。

import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)

def polycompanion(polynomial):
    # polynomial.shape[-1] 是多项式系数的个数,例如 [a, b, c, d] 代表 ax^3 + bx^2 + cx + d
    # 次数 deg = 系数个数 - 1 - 1 = 系数个数 - 2 (如果最后一个系数是常数项)
    deg = polynomial.shape[-1] - 2

    # 尝试创建伴随矩阵
    companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)

    # 填充单位矩阵部分
    companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)

    # 填充最后一列
    # 注意这里 polynomial[:-1] 表示除了最后一个系数以外的所有系数
    # polynomial[-1] 表示最后一个系数
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

# 尝试使用 vmap 向量化
polycompanion_vmap = torch.vmap(polycompanion)

try:
    print(polycompanion_vmap(poly_batched))
except Exception as e:
    print(f"Initial attempt failed: {e}")
登录后复制

上述代码在执行 polycompanion_vmap(poly_batched) 时会失败。原因是 polycompanion 函数内部通过 torch.zeros((deg+1, deg+1)) 创建了一个新的 companion 张量。尽管 deg 是从 polynomial(一个批处理输入)派生出来的,但 torch.zeros 本身创建的是一个普通的、非批处理的张量。当 vmap 试图对这个非批处理的 companion 张量执行批处理操作(例如,将其与从 polynomial 派生的批处理张量进行索引或赋值)时,就会出现维度不匹配或类型不兼容的问题,因为 vmap 期望所有参与运算的张量都带有批处理维度。

为什么 torch.zeros 不会自动批处理?

torch.vmap 的核心机制是跟踪批处理维度,并将操作提升到批处理层面。它能识别作为 vmap 输入的张量及其通过各种张量操作(如加法、乘法、切片等)派生出的张量,并为它们自动添加和管理批处理维度。然而,像 torch.zeros 这种从零开始创建新张量的操作,其默认行为是创建一个标准张量,不包含任何批处理维度信息。即使其形状参数 (deg+1, deg+1) 是基于批处理输入计算得出的,torch.zeros 也无法“感知”到外部的 vmap 上下文,从而无法自动生成一个 BatchedTensor。

torch.zeros_like 是一个例外,因为它基于一个已存在的张量来创建新张量。如果这个已存在的张量是 BatchedTensor,那么 torch.zeros_like 也能创建出一个 BatchedTensor。但在本例中,我们没有一个现成的 BatchedTensor 可以作为 zeros_like 的模板来创建 companion。

规避方案:预分配与外部传递

一种可行的(但不理想的)规避方法是,在调用 vmap 之前,手动创建一个带有批处理维度的 companion 张量,并将其作为函数的额外输入传递给 vmap。

def polycompanion_workaround(polynomial, companion_template):
    # 注意:这里的 deg 现在从 companion_template 的形状推断,因为它已经有了批处理维度
    deg = companion_template.shape[-1] - 1 

    # 在传入的 companion_template 上进行就地修改
    companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion_template

polycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)

# 预先创建批处理的 companion 模板
# poly_batched.shape[0] 是批次大小
# poly_batched.shape[-1]-1 是伴随矩阵的行/列维度
companion_init_shape = (poly_batched.shape[0], poly_batched.shape[-1] - 1, poly_batched.shape[-1] - 1)
pre_batched_companion = torch.zeros(companion_init_shape, dtype=torch.float32)

print("--- Workaround Output ---")
print(polycompanion_vmap_workaround(poly_batched, pre_batched_companion))
登录后复制

这种方法虽然能够正确输出结果,但存在明显缺点:

刺鸟创客
刺鸟创客

一款专业高效稳定的AI内容创作平台

刺鸟创客61
查看详情 刺鸟创客
  1. 函数签名改变:polycompanion 函数现在需要一个额外的 companion_template 参数,这破坏了其原始的、独立处理单个样本的语义。
  2. 外部依赖:在调用 vmap 之前,必须手动计算并创建具有正确批处理维度的 pre_batched_companion 张量,增加了代码的复杂性和耦合性。

推荐解决方案:利用 clone 和 concatenate

为了在 vmap 上下文中优雅地创建和填充张量,我们可以避免在非批处理的 torch.zeros 张量上进行就地修改。相反,我们将伴随矩阵视为由两部分组成:一个包含单位矩阵的左侧部分,以及一个由多项式系数计算得出的右侧(最后一列)部分。然后,我们分别构建这两部分,并使用 torch.concatenate 将它们合并。

关键在于:

  1. 静态部分:对于伴随矩阵中相对固定的部分(如单位矩阵),我们可以先在一个非批处理的 torch.zeros 张量上构建。
  2. 动态部分:对于依赖于批处理输入的部分(如最后一列),我们直接从批处理输入 polynomial 计算。
  3. 合并:使用 torch.concatenate 将这两部分合并。concatenate 是一种张量操作,vmap 能够很好地处理其批处理行为。

以下是改进后的 polycompanion 函数:

def polycompanion_optimized(polynomial):
    deg = polynomial.shape[-1] - 2

    # 1. 创建一个基础的非批处理张量来填充单位矩阵部分
    # 这是一个临时的、非批处理的张量
    base_matrix = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
    base_matrix[1:, :-1] = torch.eye(deg, dtype=torch.float32)

    # 2. 提取 base_matrix 的左侧部分,并进行克隆
    # clone() 创建了一个新的张量,虽然它仍然是非批处理的,
    # 但在 vmap 上下文中,当它与批处理张量拼接时,vmap 会正确处理
    left_part = base_matrix[:, :-1].clone()

    # 3. 计算伴随矩阵的最后一列
    # 这一部分完全从批处理输入 polynomial 派生,因此 vmap 会将其视为批处理张量
    # polynomial[:-1] 是 (deg+1,) 形状
    # polynomial[-1] 是标量
    # 结果是一个 (deg+1,) 形状的张量
    last_column_values = -1. * polynomial[:-1] / polynomial[-1]

    # 4. 扩展最后一列的维度,使其可以与 left_part 进行拼接
    # last_column_values 是 (deg+1,),我们需要将其变为 (deg+1, 1)
    last_column_reshaped = last_column_values[:, None] 

    # 5. 使用 concatenate 组合左右两部分
    # vmap 会识别 left_part 和 last_column_reshaped,并为它们在批次维度上执行拼接
    final_companion = torch.concatenate([left_part, last_column_reshaped], dim=1)

    return final_companion

polycompanion_vmap_optimized = torch.vmap(polycompanion_optimized)

print("\n--- Optimized Solution Output ---")
print(polycompanion_vmap_optimized(poly_batched))
登录后复制

输出:

tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])
登录后复制

这个解决方案成功地生成了批处理的伴随矩阵,同时保持了 polycompanion_optimized 函数的简洁性,使其能够独立处理单个样本,并且不需要外部预分配张量。

注意事项与最佳实践

  • 函数式编程思维:在使用 torch.vmap 时,尽量采用函数式编程的思维,即函数主要通过返回新张量来完成操作,而不是通过就地修改输入张量。这有助于 vmap 更好地跟踪张量的依赖关系和批处理维度。
  • 避免在 vmap 内部进行就地修改:除非你确切知道自己在做什么,并且只对批处理输入进行就地修改,否则应避免在 vmap 内部对非批处理张量进行就地修改。
  • clone() 的作用:在上述解决方案中,clone() 是关键。它创建了一个 base_matrix 切片的新副本。虽然 base_matrix 本身是非批处理的,但通过 clone() 得到的 left_part 可以被 concatenate 操作正确地与批处理的 last_column_reshaped 结合。
  • 维度匹配:当使用 torch.concatenate 或 torch.stack 时,确保所有参与拼接的张量在非拼接维度上形状一致。[:, None] 技巧常用于为张量添加一个维度,使其符合拼接要求。
  • 性能考量:虽然 concatenate 方案解决了功能问题,但频繁创建和拼接中间张量可能会带来一定的性能开销。对于极致性能敏感的场景,可能需要权衡 vmap 的便利性与手动批处理的优化潜力。然而,对于大多数情况,vmap 带来的代码简化和潜在加速(尤其是在支持的后端)是值得的。

总结

在 torch.vmap 中处理函数内部的张量创建是一个常见的挑战。通过理解 vmap 对批处理张量的期望,并采用 clone() 结合 torch.concatenate 的策略,我们能够优雅地构建出所需的批处理张量,而无需妥协函数的简洁性或引入复杂的外部依赖。这种方法体现了在 PyTorch 中进行高效张量操作的灵活性和强大功能,是掌握 torch.vmap 的一个重要技巧。

以上就是在 torch.vmap 中高效处理内部张量创建的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号