高效查找 PyTorch 张量中唯一行的索引

聖光之護
发布: 2025-10-04 12:10:15
原创
653人浏览过

高效查找 pytorch 张量中唯一行的索引

本文介绍了一种在 PyTorch 张量中高效查找每个唯一行首次出现索引的方法。通过利用 torch.unique 函数获取唯一行及其逆向索引,并结合二维张量和 torch.argmin 函数,避免了显式循环,从而提升了代码效率。文章提供了详细的代码示例和性能注意事项,帮助读者根据实际应用场景选择合适的解决方案。

在 PyTorch 中处理张量数据时,经常需要查找唯一行的索引。一种常见的方法是使用循环遍历每个唯一行,并在逆向索引中找到其首次出现的索引。然而,这种方法效率较低,尤其是在处理大型张量时。本文介绍一种更高效的方法,利用 PyTorch 的张量操作避免显式循环,从而提高代码性能。

使用 torch.unique 获取唯一行和逆向索引

首先,使用 torch.unique 函数获取张量中的唯一行、逆向索引和计数。torch.unique 函数的 return_inverse=True 参数会返回一个逆向索引张量,该张量指示原始张量中的每一行对应于唯一行张量中的哪个索引。

import torch
import numpy as np

# 示例张量
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# 查找唯一行
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)
登录后复制

使用二维张量和 torch.argmin 查找首次出现索引

为了避免循环,我们可以创建一个二维张量 A,其维度为原始张量的行数乘以唯一行的数量。将 A 初始化为一个较大的值(例如 1000,确保大于原始张量的行数),表示“未定义的行索引”。然后,对于原始张量的每个行索引 i,将 A[i, inverse_indices[i]] 设置为 inverse_indices[i]。

纳米搜索
纳米搜索

纳米搜索:360推出的新一代AI搜索引擎

纳米搜索30
查看详情 纳米搜索
A = 1000 * torch.ones((len(data), len(u_data)), dtype=torch.long)
A[torch.arange(len(data)), inverse_indices] = inverse_indices
登录后复制

现在,考虑按列查看张量 A。第 j 列对应于第 j 个唯一行。该列的大部分值为 1000,但某些行将包含 j。该列的 argmin 就是映射到唯一行 j 的第一个原始行的索引。

unique_indices2 = torch.argmin(A, dim=0)
登录后复制

完整代码示例

import torch
import numpy as np

# 示例张量
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# 查找唯一行
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)

# 使用循环查找首次出现索引(作为参考)
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
    unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]

# 使用二维张量和 argmin 查找首次出现索引
A = 1000 * torch.ones((len(data), len(u_data)), dtype=torch.long)
A[torch.arange(len(data)), inverse_indices] = inverse_indices
unique_indices2 = torch.argmin(A, dim=0)

# 验证结果
print(torch.allclose(unique_indices2,unique_indices))
登录后复制

性能注意事项

虽然这种方法避免了循环和 torch.where 函数,但它使用了更多的内存。argmin 函数在硬件上的速度、实际问题的维度以及对内存的重视程度都会影响其效率。在实际应用中,需要根据具体情况权衡内存使用和计算速度,选择最合适的解决方案。如果数据量较小,循环方式可能更简单易懂;如果数据量较大,且对性能要求较高,则可以考虑使用本文介绍的基于张量操作的方法。

以上就是高效查找 PyTorch 张量中唯一行的索引的详细内容,更多请关注php中文网其它相关文章!

相关标签:
最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号