
本文探讨了在NumPy三维数组中高效检查子数组是否存在于另一个三维数组中的两种方法。针对传统np.isin和np.in1d在多维数组上的局限性,文章详细介绍了基于字符串转换的np.in1d方案和利用广播机制的直接比较方案,并提供了相应的代码示例和注意事项,旨在帮助读者解决复杂多维数组的查找问题。
在数据处理和科学计算中,我们经常需要判断一个数组中的元素是否存在于另一个数组中。对于NumPy的低维数组,np.isin()或np.in1d()函数能够很好地完成这项任务。然而,当处理高维数组,特别是需要检查“子数组”而非单个元素的存在性时,这些函数可能无法直接满足需求,或者产生不符合预期的结果。例如,给定两个三维NumPy数组source和values,我们希望检查source中每一个形如[x,y,z]的子数组是否在values中出现过,并返回一个布尔数组,其长度与source中待检查的子数组数量一致。
假设我们有以下两个NumPy数组:
import numpy as np source = np.array([[[0,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,1],[1,1,0],[1,1,1]]]) values = np.array([[[0,1,0],[1,0,0],[1,1,1],[1,1,1],[0,1,0]]])
其中,source的形状为(1, 7, 3),values的形状为(1, 5, 3)。我们的目标是得到一个长度为7的布尔数组,表示source中每个[*,*,*]子数组是否存在于values中。例如,source[0,2,:]即[0,1,0]在values中存在,source[0,3,:]即[1,0,0]也存在。期望的输出结果应为 [False, False, True, True, False, False, True]。
以下将介绍两种有效解决此问题的方法。
这种方法的核心思想是将每个待比较的子数组(例如[0,0,0])转换成一个唯一的字符串表示。这样,原本的子数组比较问题就转化成了字符串的比较问题,从而可以利用np.in1d()函数进行高效查找。
将子数组转换为字符串: 使用np.apply_along_axis函数,沿着指定的轴(在这里是最后一个轴,即轴2)对数组中的每个子数组应用一个转换函数。我们将每个子数组的元素转换为字符串,然后用''.join连接起来。
source_str = np.apply_along_axis(''.join, 2, source.astype(str))
values_str = np.apply_along_axis(''.join, 2, values.astype(str))经过转换后,source_str和values_str将包含表示原始子数组的字符串。例如,[0,1,0]会变成"010"。
使用 np.in1d 进行查找: 现在,source_str和values_str实际上是二维数组(形状分别为(1, 7)和(1, 5)),但它们的内容是字符串。我们可以将其进一步展平为一维数组,然后直接应用np.in1d进行查找。
result = np.in1d(source_str.flatten(), values_str.flatten()) print(result)
完整代码示例:
import numpy as np
source = np.array([[[0,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,1],[1,1,0],[1,1,1]]])
values = np.array([[[0,1,0],[1,0,0],[1,1,1],[1,1,1],[0,1,0]]])
# 将每个子数组转换为字符串
source_str = np.apply_along_axis(''.join, 2, source.astype(str))
values_str = np.apply_along_axis(''.join, 2, values.astype(str))
# 使用 np.in1d 进行查找
result_method1 = np.in1d(source_str.flatten(), values_str.flatten())
print("方法一结果:", result_method1)
# 输出: 方法一结果: [False False True True False False True]注意事项:
这种方法通过巧妙地利用NumPy的广播机制,避免了显式的循环,直接在数组维度上进行比较。它的优势在于能够保持数值类型,避免字符串转换的潜在问题,但可能在极端情况下对内存有较高要求。
调整 source 数组维度: 为了让source中的每个子数组(例如source[0,i,:])都能与values中的所有子数组(例如values[0,j,:])进行比较,我们需要调整source的维度。通过transpose(1,0,2)操作,将source从(1, N, 3)变为(N, 1, 3)。这样,source[i]就代表source中的第i个子数组,并且维度适合与values进行广播。
# 假设source的实际有效部分是source[0,:,:],values的有效部分是values[0,:,:] # 如果数组的第一个维度始终为1,可以先去除: source_flat = source[0] # 形状变为 (7, 3) values_flat = values[0] # 形状变为 (5, 3) # 调整source_flat的维度以进行广播比较 # source_flat[:, None, :] 形状变为 (7, 1, 3)
执行广播比较: 将调整后的source_flat与values_flat进行元素级相等比较。 source_flat[:, None, :] == values_flat 这将 (7, 1, 3) 的数组与 (5, 3) 的数组进行比较。NumPy的广播规则会将其扩展为 (7, 5, 3) 的布尔数组。其中,result[i, j, k] 表示source_flat[i, k]是否等于values_flat[j, k]。
聚合比较结果:
完整代码示例:
import numpy as np
source = np.array([[[0,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,1],[1,1,0],[1,1,1]]])
values = np.array([[[0,1,0],[1,0,0],[1,1,1],[1,1,1],[0,1,0]]])
# 假设数组的第一个维度始终为1,先去除
source_flat = source[0] # 形状 (7, 3)
values_flat = values[0] # 形状 (5, 3)
# 利用广播机制进行比较
# source_flat[:, None, :] 形状变为 (7, 1, 3)
# 与 values_flat (5, 3) 比较,广播为 (7, 5, 3)
# .all(axis=2) 检查每个子数组是否完全匹配,得到 (7, 5)
# .any(axis=1) 检查source中的每个子数组是否至少匹配values中的一个,得到 (7,)
result_method2 = (source_flat[:, None, :] == values_flat).all(axis=2).any(axis=1)
print("方法二结果:", result_method2)
# 输出: 方法二结果: [False False True True False False True]注意事项:
两种方法都能有效地解决在NumPy三维数组中查找子数组存在性的问题,并返回预期的布尔数组。
字符串转换法(方法一):
广播比较法(方法二):
在实际应用中,您可以根据数据的特性(类型、规模)和对性能、内存的需求来选择最合适的方法。如果内存是主要限制因素,或者数据类型复杂(例如混合类型),字符串转换可能更安全。如果数据是纯数值型且内存允许,广播比较通常能提供更优的性能。
以上就是NumPy三维数组中子数组存在性检查的高效策略的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号