
在使用`torch.vmap`进行函数向量化时,直接在被向量化的函数内部使用`torch.zeros`创建新的张量并期望其自动获得批处理维度是一个常见挑战。本文将深入探讨这一问题,并提供一种优雅的解决方案:通过结合`clone()`和`torch.concatenate`,可以有效地在`vmap`环境中创建和填充具有正确批处理维度的张量,从而避免手动传递预先创建的批处理张量,实现代码的简洁与高效。
torch.vmap是PyTorch中一个强大的工具,它允许用户对批量输入高效地应用一个单样本函数,而无需手动编写循环或调整张量维度。然而,当被向量化的函数需要在内部创建新的张量时,一个常见的陷阱是这些新创建的张量并不会自动继承批处理维度。
考虑一个计算多项式伴随矩阵的函数polycompanion。这个函数需要根据输入多项式polynomial的维度创建一个新的零矩阵companion,然后填充其部分内容。
import torch
poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)
def polycompanion(polynomial):
    # 计算伴随矩阵的维度
    deg = polynomial.shape[-1] - 2
    # 创建一个 (deg+1, deg+1) 的零矩阵
    companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
    # 填充单位矩阵部分
    companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    # 填充最后一列,这部分依赖于输入多项式
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion
# 尝试使用 vmap 向量化该函数
polycompanion_vmap = torch.vmap(polycompanion)
# 预期会遇到问题,因为 companion 不是 BatchedTensor
# print(polycompanion_vmap(poly_batched))
# 上述代码会因 vmap 无法处理非 BatchedTensor 的原地操作而失败在上述代码中,torch.vmap在执行polycompanion时,polynomial是一个BatchedTensor。然而,companion = torch.zeros((deg + 1, deg + 1))创建的companion张量并不是BatchedTensor。当尝试对companion进行原地修改,特别是当修改操作涉及polynomial(一个BatchedTensor)时,vmap无法正确地跟踪和应用批处理语义,导致运行时错误。
为了规避这个问题,一种常见的(但不推荐的)做法是预先在vmap外部创建批处理的零张量,并将其作为参数传递给被向量化的函数。
import torch
poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)
def polycompanion_workaround(polynomial, companion_template):
    # 注意:这里的 deg 需要根据 companion_template 的形状来推断,或者与 polynomial 保持一致
    # 为了简化,我们假设 companion_template 已经有正确的形状
    deg = companion_template.shape[-1] - 1 # 假设 companion_template 已经是 (deg+1, deg+1)
    # 在 template 上进行原地操作
    companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion_template
polycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)
# 预先创建批处理的零张量
batch_size = poly_batched.shape[0]
companion_dim = poly_batched.shape[-1] - 1 # (deg+1)
initial_companion = torch.zeros(batch_size, companion_dim, companion_dim, dtype=torch.float32)
# 传递预创建的批处理张量
output_workaround = polycompanion_vmap_workaround(poly_batched, initial_companion)
print("Workaround Output:")
print(output_workaround)输出:
Workaround Output:
tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],
        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])这种方法虽然能工作,但它破坏了函数的封装性,使得函数签名的设计变得复杂,且在函数内部无法动态决定新张量的批处理大小,不够灵活。
解决此问题的关键在于,对于需要批处理的张量,我们必须确保其批处理维度在vmap的上下文中是明确的。如果一个张量的一部分内容依赖于批处理输入,而另一部分是固定的,我们可以将它们分别处理,然后合并。
核心思路是:
import torch
poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)
def polycompanion_refined(polynomial):
    deg = polynomial.shape[-1] - 2
    # 1. 创建一个非批处理的零矩阵作为基础
    companion_base = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)
    # 2. 填充单位矩阵部分(这部分是固定的,不依赖于批处理)
    # 注意:这里我们只填充除了最后一列之外的部分
    companion_base[1:, :-1] = torch.eye(deg, dtype=torch.float32)
    # 3. 计算最后一列,这部分是依赖于 polynomial (BatchedTensor) 的,因此会是 BatchedTensor
    last_column_batched = -1. * polynomial[:-1] / polynomial[-1]
    # 4. 准备合并:
    #    - companion_base[:, :-1] 是非批处理的,需要 clone 以便后续操作。
    #      clone() 确保 vmap 可以对每个批次独立处理这个副本。
    #    - last_column_batched 是一个一维的 BatchedTensor,形状为 (batch_size, deg+1)。
    #      为了与 companion_base[:, :-1] (形状为 (deg+1, deg)) 合并,
    #      需要将其扩展为 (batch_size, deg+1, 1) 的形状,通过 [:, None] 实现。
    _companion = torch.concatenate([
        companion_base[:, :-1].clone(), # 克隆非批处理的左侧部分
        last_column_batched[:, None]    # 批处理的右侧列,添加一个维度使其可合并
    ], dim=1) # 沿着列维度合并
    return _companion
polycompanion_vmap_refined = torch.vmap(polycompanion_refined)
output_refined = polycompanion_vmap_refined(poly_batched)
print("\nRefined Solution Output:")
print(output_refined)输出:
Refined Solution Output:
tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],
        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])通过这种clone()和torch.concatenate的组合,我们能够在torch.vmap的上下文中,在函数内部灵活且优雅地创建和填充新的批处理张量,从而保持代码的简洁性和功能性,避免了不必要的外部参数传递。这种模式对于在vmap函数中构建复杂张量结构非常有用。
以上就是在torch.vmap中高效创建与操作批处理张量的详细内容,更多请关注php中文网其它相关文章!
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
                
                                
                                
                                
                                
                                
                                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号