
在数据处理和机器学习任务中,我们经常需要处理包含重复数据的张量(tensor)。当需要识别张量中所有唯一行,并进一步获取这些唯一行在原始张量中首次出现的索引时,一个常见的挑战是效率问题。
PyTorch提供了torch.unique函数来方便地找出张量中的唯一行及其相关信息。例如,torch.unique(data, dim=0, return_inverse=True)会返回唯一行、以及一个inverse_indices张量,该张量将原始张量中的每个行映射到其对应的唯一行索引。
然而,要根据inverse_indices找出每个唯一行在原始张量中首次出现的索引,一个直观但效率低下的方法是使用Python循环:
import torch
import numpy as np
# 示例张量
data = torch.rand(100, 5)
# 引入一些重复行
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# 查找唯一行及其逆索引
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)
# 传统方法:通过循环查找每个唯一行的首次出现索引
# 这个循环是效率瓶颈所在
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]
print("传统方法得到的首次出现索引:", unique_indices)上述代码中,for循环遍历每个唯一行的索引idx,然后使用torch.where查找inverse_indices中所有等于idx的位置,并取第一个位置作为首次出现的索引。这种逐个查找的循环方式,尤其是在处理大型张量时,会导致显著的性能开销,因为它涉及多次Python循环迭代和张量条件查找操作。
为了避免上述低效的循环,我们可以采用一种更符合PyTorch风格的向量化方法。其核心思想是构建一个辅助的二维张量,巧妙地利用其结构,并通过torch.argmin操作来高效地找出首次出现的索引。
核心思路:
示例代码:
import torch
import numpy as np
# 示例张量 (与问题部分相同)
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# 查找唯一行及其逆索引
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)
# 优化方法:基于二维张量和argmin
num_original_rows = len(data)
num_unique_rows = len(u_data)
# 1. 创建辅助张量A,并用一个大值(如1000,确保大于任何可能的行索引)初始化
# dtype应为long以匹配索引类型
placeholder_value = num_original_rows + 100 # 确保占位符大于最大行索引
A = placeholder_value * torch.ones((num_original_rows, num_unique_rows), dtype=torch.long)
# 2. 填充张量A
# A[i, inverse_indices[i]] = i
# torch.arange(num_original_rows) 生成 [0, 1, ..., num_original_rows-1]
# inverse_indices 提供了每个原始行对应的唯一行索引
# 这样,A[i, j] = i 当且仅当原始行 i 属于唯一行组 j
A[torch.arange(num_original_rows), inverse_indices] = torch.arange(num_original_rows)
# 3. 使用argmin查找首次出现索引
# 沿dim=0(列方向)查找最小值,即找到每个唯一行组的最小原始行索引
unique_indices_optimized = torch.argmin(A, dim=0)
print("优化方法得到的首次出现索引:", unique_indices_optimized)
# 验证两种方法结果是否一致
# (为了验证,这里重新计算了传统方法的结果)
unique_indices_traditional = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
unique_indices_traditional[idx] = torch.where(inverse_indices == idx)[0][0]
print("两种方法结果是否一致:", torch.allclose(unique_indices_optimized, unique_indices_traditional))代码解释:
总结:
在选择方法时,需要根据实际应用场景进行权衡:
总而言之,通过巧妙地利用PyTorch的张量操作,我们可以将复杂的循环逻辑转化为高效的向量化计算,从而在处理数据时获得更好的性能。
以上就是PyTorch张量唯一行首次出现索引的高效查找方法的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号