
在数据分析和机器学习中,我们经常需要计算两个向量集合 a 和 b 之间所有可能的成对距离。然而,在某些特定场景下,我们可能只对其中一小部分成对距离感兴趣,例如,当一个掩码矩阵 m 指定了哪些距离是必要的时。
考虑以下一个小型示例:
import numpy as np A = np.array([[1, 2], [2, 3], [3, 4]]) # (3, 2) B = np.array([[4, 5], [5, 6], [6, 7], [7, 8], [8, 9]]) # (5, 2) M = np.array([[0, 0, 0, 1, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 1]]) # (3, 5)
传统的做法是先计算所有成对向量的差值,然后计算它们的范数(通常是欧氏距离),最后再通过掩码矩阵 M 筛选出所需的距离。
diff = A[:,None] - B[None,:] # (3, 5, 2) distances = np.linalg.norm(diff, ord=2, axis=2) # (3, 5) masked_distances = distances * M # (3, 5)
这种方法的问题在于,即使我们只需要极少数的距离,np.linalg.norm 仍然会计算所有 A.shape[0] * B.shape[0] 个距离。当 A 和 B 的行数达到数千甚至更多时,这种不必要的计算会导致巨大的性能开销和内存浪费。特别是当掩码矩阵 M 的非零元素比例低于1%时,这种低效性更为突出。
尝试使用 np.vectorize 结合条件判断虽然可以避免计算不必要的差值,但在实际测试中,对于大型数组,其性能反而更差,因为它引入了Python级别的循环开销。
立即学习“Python免费学习笔记(深入)”;
为了解决上述效率问题,我们可以结合 Numba 的即时编译(JIT)能力和 SciPy 的稀疏矩阵(Compressed Sparse Row, CSR)结构。这种方法的核心思想是:
Numba在循环中执行自定义函数通常比调用NumPy的 np.linalg.norm 更快。因此,我们首先定义一个Numba加速的欧氏距离计算函数:
import numba as nb
import numpy as np
import scipy
import math
@nb.njit()
def euclidean_distance(vec_a, vec_b):
"""
计算两个向量之间的欧氏距离。
使用Numba进行JIT编译以提高性能。
"""
acc = 0.0
for i in range(vec_a.shape[0]):
acc += (vec_a[i] - vec_b[i]) ** 2
return math.sqrt(acc)这个函数直接计算了两个向量的欧氏距离平方和的平方根。@nb.njit() 装饰器指示 Numba 在函数首次调用时将其编译为机器码。
CSR矩阵通过三个数组来表示稀疏数据:
masked_distance_inner 函数负责遍历掩码矩阵 M,并在条件满足时计算距离并填充这三个数组:
@nb.njit()
def masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask):
"""
Numba JIT编译的核心函数,用于根据掩码计算并填充稀疏矩阵的数据。
参数:
data (np.ndarray): 存储非零距离值的数组。
indicies (np.ndarray): 存储非零距离值对应列索引的数组。
indptr (np.ndarray): 存储每行在data和indicies中起始位置的数组。
matrix_a (np.ndarray): 第一个向量集合。
matrix_b (np.ndarray): 第二个向量集合。
mask (np.ndarray): 布尔型掩码矩阵,指示哪些距离需要计算。
"""
write_pos = 0 # 当前写入data和indicies的位置
N, M = matrix_a.shape[0], matrix_b.shape[0]
for i in range(N): # 遍历 matrix_a 的每一行
for j in range(M): # 遍历 matrix_b 的每一行
if mask[i, j]: # 如果掩码指示该距离需要计算
# 记录距离值
data[write_pos] = euclidean_distance(matrix_a[i], matrix_b[j])
# 记录该距离值对应的列索引
indicies[write_pos] = j
write_pos += 1
# 记录当前行结束后,下一行在data和indicies中的起始位置
indptr[i + 1] = write_pos
# 断言所有预分配的内存都被使用
assert write_pos == data.shape[0]
assert write_pos == indicies.shape[0]这个函数通过双重循环遍历所有可能的 (i, j) 对。只有当 mask[i, j] 为 True 时,才会调用 euclidean_distance 计算距离,并将结果存储到 data 数组中,同时记录其列索引到 indicies 数组。indptr 数组则在每行遍历结束后更新,以正确标记下一行的起始位置。
masked_distance 函数负责初始化 data、indicies 和 indptr 数组,并调用 masked_distance_inner 完成计算,最后构建并返回 scipy.sparse.csr_matrix 对象。
def masked_distance(matrix_a, matrix_b, mask):
"""
计算并返回一个稀疏矩阵,其中包含根据掩码筛选出的成对欧氏距离。
参数:
matrix_a (np.ndarray): 第一个向量集合。
matrix_b (np.ndarray): 第二个向量集合。
mask (np.ndarray): 布尔型掩码矩阵,指示哪些距离需要计算。
返回:
scipy.sparse.csr_matrix: 包含所需距离的稀疏矩阵。
"""
N, M = matrix_a.shape[0], matrix_b.shape[0]
assert mask.shape == (N, M)
# 确保掩码是布尔类型
mask = mask != 0
# 计算稀疏矩阵将包含的非零元素总数
sparse_length = mask.sum()
# 预分配存储稀疏矩阵数据的数组
# 注意:这些数组不需要初始化为零,Numba函数会直接写入
data = np.empty(sparse_length, dtype='float64') # 存储距离值
indicies = np.empty(sparse_length, dtype='int64') # 存储列索引
indptr = np.zeros(N + 1, dtype='int64') # 存储行指针,第一个元素为0
# 调用Numba加速的核心函数进行计算和填充
masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask)
# 构建并返回SciPy的CSR稀疏矩阵
return scipy.sparse.csr_matrix((data, indicies, indptr), shape=(N, M))这个函数首先验证了输入掩码的形状,然后统计掩码中 True 值的数量,这决定了 data 和 indicies 数组的大小。indptr 数组的大小为 N + 1,其中 N 是 matrix_a 的行数,indptr[0] 总是 0。最后,它使用填充好的 data、indicies 和 indptr 数组以及目标矩阵的形状来构造 csr_matrix。
为了演示其效果,我们使用较大的随机数据进行测试:
# 生成较大的随机数据 A_big = np.random.rand(2000, 10) B_big = np.random.rand(4000, 10) # 生成一个非常稀疏的掩码,只有0.1%的元素为True M_big = np.random.rand(A_big.shape[0], B_big.shape[0]) < 0.001 # 使用 %timeit 魔法命令测量执行时间 # %timeit masked_distance(A_big, B_big, M_big)
在原问题提供的基准测试中,对于 A_big (2000, 10) 和 B_big (4000, 10),且 M_big 只有0.1%的元素为 True 的情况下,此方法比原始的全矩阵计算方法快了约 40倍。当向量维度更高(例如1000维)时,性能提升甚至可达 1000倍。
通过将 Numba 的JIT编译能力与 SciPy 的 CSR 稀疏矩阵结构相结合,我们成功地为大规模向量集合中稀疏的成对距离计算提供了一个高效的解决方案。这种方法避免了不必要的计算和内存分配,特别适用于当所需距离仅占总数极小比例的场景,能够带来数十倍甚至上千倍的性能提升。在处理大规模稀疏数据时,理解并应用此类优化技术对于构建高性能的数值计算系统至关重要。
以上就是优化Python中稀疏交叉差分距离计算的教程的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号